mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-17 16:09:13 +00:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec5bd236f8 | ||
|
|
a51cdb04e6 | ||
|
|
f50d5df3d0 | ||
|
|
4ac2ccf473 | ||
|
|
5f7a72a553 | ||
|
|
643b004727 | ||
|
|
72e0415184 | ||
|
|
28cb558d10 | ||
|
|
23806b0db1 | ||
|
|
6a80499823 | ||
|
|
110f36bdf9 | ||
|
|
f85426022d | ||
|
|
78fd0cbee5 | ||
|
|
0d59065719 | ||
|
|
6110e2d6e2 | ||
|
|
275b8c38a2 | ||
|
|
fd1244c471 | ||
|
|
f11d294825 | ||
|
|
22b702fcda | ||
|
|
831817a737 | ||
|
|
7329d9f2fb | ||
|
|
3aad1d5d79 | ||
|
|
f72c599d2d |
10
.github/workflows/go.yml
vendored
10
.github/workflows/go.yml
vendored
@@ -28,7 +28,13 @@ jobs:
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
|
||||
- if: matrix.os == 'ubuntu-latest'
|
||||
name: Update coverage report
|
||||
- name: Test data races
|
||||
run: go test -v -race ./...
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
- name: Update coverage report
|
||||
uses: ncruces/go-coverage-report@main
|
||||
if: |
|
||||
matrix.os == 'ubuntu-latest' &&
|
||||
github.event_name == 'push'
|
||||
continue-on-error: true
|
||||
|
||||
11
README.md
11
README.md
@@ -6,13 +6,16 @@
|
||||
|
||||
⚠️ CAUTION ⚠️
|
||||
|
||||
This is still very much a WIP.\
|
||||
DO NOT USE this with data you care about.
|
||||
This is a WIP.\
|
||||
DO NOT USE with data you care about.
|
||||
|
||||
Roadmap:
|
||||
- [x] build SQLite using `zig cc --target=wasm32-wasi`
|
||||
- [x] `:memory:` databases
|
||||
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
|
||||
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
|
||||
- [x] come up with a simple, nice API, enough for simple queries
|
||||
- [x] file locking, compatible with SQLite on Windows/Unix
|
||||
- [x] design a simple, nice API, enough for simple use cases
|
||||
- [x] provide a simple `database/sql` driver
|
||||
- [x] file locking, compatible with SQLite on Windows/Unix
|
||||
- [ ] shared memory, compatible with SQLite on Windows/Unix
|
||||
- needed for improved WAL mode
|
||||
25
api.go
25
api.go
@@ -1,3 +1,4 @@
|
||||
// Package sqlite3 wraps the C SQLite API.
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
@@ -44,18 +45,26 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
bindName: getFun("sqlite3_bind_parameter_name"),
|
||||
bindNull: getFun("sqlite3_bind_null"),
|
||||
bindInteger: getFun("sqlite3_bind_int64"),
|
||||
bindFloat: getFun("sqlite3_bind_double"),
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
bindNull: getFun("sqlite3_bind_null"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
columnInteger: getFun("sqlite3_column_int64"),
|
||||
columnFloat: getFun("sqlite3_column_double"),
|
||||
columnText: getFun("sqlite3_column_text"),
|
||||
columnBlob: getFun("sqlite3_column_blob"),
|
||||
columnBytes: getFun("sqlite3_column_bytes"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
interrupt: getFun("sqlite3_interrupt"),
|
||||
},
|
||||
}
|
||||
if err != nil {
|
||||
@@ -80,16 +89,24 @@ type sqliteAPI struct {
|
||||
step api.Function
|
||||
exec api.Function
|
||||
clearBindings api.Function
|
||||
bindNull api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
bindNull api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
columnInteger api.Function
|
||||
columnFloat api.Function
|
||||
columnText api.Function
|
||||
columnBlob api.Function
|
||||
columnBytes api.Function
|
||||
columnType api.Function
|
||||
lastRowid api.Function
|
||||
changes api.Function
|
||||
interrupt api.Function
|
||||
}
|
||||
|
||||
15
compile.go
15
compile.go
@@ -24,13 +24,11 @@ type sqlite3Runtime struct {
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
instances atomic.Uint64
|
||||
ctx context.Context
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
|
||||
s.ctx = ctx
|
||||
s.once.Do(s.compileModule)
|
||||
s.once.Do(func() { s.compileModule(ctx) })
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
@@ -40,12 +38,9 @@ func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, err
|
||||
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
|
||||
}
|
||||
|
||||
func (s *sqlite3Runtime) compileModule() {
|
||||
s.runtime = wazero.NewRuntime(s.ctx)
|
||||
s.err = vfsInstantiate(s.ctx, s.runtime)
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
|
||||
s.runtime = wazero.NewRuntime(ctx)
|
||||
vfsInstantiate(ctx, s.runtime)
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
@@ -55,5 +50,5 @@ func (s *sqlite3Runtime) compileModule() {
|
||||
}
|
||||
}
|
||||
|
||||
s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin)
|
||||
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
206
conn.go
206
conn.go
@@ -13,6 +13,11 @@ type Conn struct {
|
||||
api sqliteAPI
|
||||
mem memory
|
||||
handle uint32
|
||||
|
||||
arena arena
|
||||
pending *Stmt
|
||||
waiter chan struct{}
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
|
||||
@@ -39,15 +44,15 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.arena = c.newArena(1024)
|
||||
|
||||
namePtr := c.newString(filename)
|
||||
connPtr := c.new(ptrlen)
|
||||
defer c.free(namePtr)
|
||||
defer c.free(connPtr)
|
||||
defer c.arena.reset()
|
||||
connPtr := c.arena.new(ptrlen)
|
||||
namePtr := c.arena.string(filename)
|
||||
|
||||
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c.handle = c.mem.readUint32(connPtr)
|
||||
@@ -65,13 +70,15 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/close.html
|
||||
func (c *Conn) Close() error {
|
||||
if c == nil {
|
||||
if c == nil || c.handle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.SetInterrupt(nil)
|
||||
|
||||
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := c.error(r[0]); err != nil {
|
||||
@@ -82,17 +89,80 @@ func (c *Conn) Close() error {
|
||||
return c.mem.mod.Close(c.ctx)
|
||||
}
|
||||
|
||||
// SetInterrupt interrupts a long-running query when done is closed.
|
||||
//
|
||||
// Subsequent uses of the connection will return [INTERRUPT]
|
||||
// until done is reset by another call to SetInterrupt.
|
||||
//
|
||||
// Typically, done is provided by [context.Context.Done]:
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
|
||||
// conn.SetInterrupt(ctx.Done())
|
||||
// defer cancel()
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
<-c.waiter // Wait for it to finish.
|
||||
c.waiter = nil
|
||||
}
|
||||
|
||||
// Finalize the uncompleted SQL statement.
|
||||
if c.pending != nil {
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
}
|
||||
|
||||
old = c.done
|
||||
c.done = done
|
||||
if done == nil {
|
||||
return old
|
||||
}
|
||||
|
||||
// Creating an uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
c.pending.Step()
|
||||
|
||||
waiter := make(chan struct{})
|
||||
c.waiter = waiter
|
||||
go func() {
|
||||
select {
|
||||
case <-waiter: // Waiter was cancelled.
|
||||
break
|
||||
|
||||
case <-done: // Done was closed.
|
||||
|
||||
// This is safe to call from a goroutine
|
||||
// because it doesn't touch the C stack.
|
||||
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
}
|
||||
|
||||
// Signal that the waiter has finished.
|
||||
waiter <- struct{}{}
|
||||
}()
|
||||
return old
|
||||
}
|
||||
|
||||
// Exec is a convenience function that allows an application to run
|
||||
// multiple statements of SQL without having to use a lot of code.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/exec.html
|
||||
func (c *Conn) Exec(sql string) error {
|
||||
sqlPtr := c.newString(sql)
|
||||
defer c.free(sqlPtr)
|
||||
defer c.arena.reset()
|
||||
sqlPtr := c.arena.string(sql)
|
||||
|
||||
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return c.error(r[0])
|
||||
}
|
||||
@@ -109,18 +179,16 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/prepare.html
|
||||
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
|
||||
sqlPtr := c.newString(sql)
|
||||
stmtPtr := c.new(ptrlen)
|
||||
tailPtr := c.new(ptrlen)
|
||||
defer c.free(sqlPtr)
|
||||
defer c.free(stmtPtr)
|
||||
defer c.free(tailPtr)
|
||||
defer c.arena.reset()
|
||||
stmtPtr := c.arena.new(ptrlen)
|
||||
tailPtr := c.arena.new(ptrlen)
|
||||
sqlPtr := c.arena.string(sql)
|
||||
|
||||
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
|
||||
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
|
||||
uint64(stmtPtr), uint64(tailPtr))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
stmt = &Stmt{c: c}
|
||||
@@ -137,6 +205,31 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
return
|
||||
}
|
||||
|
||||
// LastInsertRowID returns the rowid of the most recent successful INSERT
|
||||
// on the database connection.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/last_insert_rowid.html
|
||||
func (c *Conn) LastInsertRowID() uint64 {
|
||||
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r[0]
|
||||
}
|
||||
|
||||
// Changes returns the number of rows modified, inserted or deleted
|
||||
// by the most recently completed INSERT, UPDATE or DELETE statement
|
||||
// on the database connection.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/changes.html
|
||||
func (c *Conn) Changes() uint64 {
|
||||
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r[0]
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
@@ -150,28 +243,26 @@ func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
|
||||
var r []uint64
|
||||
|
||||
// sqlite3_errmsg is guaranteed to never change the value of the error code.
|
||||
r, _ = c.api.errstr.Call(c.ctx, rc)
|
||||
if r != nil {
|
||||
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
|
||||
}
|
||||
|
||||
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
|
||||
if r != nil {
|
||||
err.msg = c.mem.readString(uint32(r[0]), 512)
|
||||
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
// sqlite3_error_offset is guaranteed to never change the value of the error code.
|
||||
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
|
||||
if r != nil && r[0] != math.MaxUint32 {
|
||||
err.sql = sql[0][r[0]:]
|
||||
}
|
||||
}
|
||||
|
||||
r, _ = c.api.errstr.Call(c.ctx, rc)
|
||||
if r != nil {
|
||||
err.str = c.mem.readString(uint32(r[0]), 512)
|
||||
}
|
||||
|
||||
if err.msg == err.str {
|
||||
switch err.msg {
|
||||
case err.str, "not an error":
|
||||
err.msg = ""
|
||||
|
||||
}
|
||||
return &err
|
||||
}
|
||||
@@ -186,13 +277,13 @@ func (c *Conn) free(ptr uint32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) new(len uint32) uint32 {
|
||||
r, err := c.api.malloc.Call(c.ctx, uint64(len))
|
||||
func (c *Conn) new(size uint32) uint32 {
|
||||
r, err := c.api.malloc.Call(c.ctx, uint64(size))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 && len != 0 {
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(oomErr)
|
||||
}
|
||||
return ptr
|
||||
@@ -202,19 +293,54 @@ func (c *Conn) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
siz := uint32(len(b))
|
||||
ptr := c.new(siz)
|
||||
buf := c.mem.view(ptr, siz)
|
||||
copy(buf, b)
|
||||
ptr := c.new(uint32(len(b)))
|
||||
c.mem.writeBytes(ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (c *Conn) newString(s string) uint32 {
|
||||
siz := uint32(len(s) + 1)
|
||||
ptr := c.new(siz)
|
||||
buf := c.mem.view(ptr, siz)
|
||||
buf[len(s)] = 0
|
||||
copy(buf, s)
|
||||
ptr := c.new(uint32(len(s) + 1))
|
||||
c.mem.writeString(ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (c *Conn) newArena(size uint32) arena {
|
||||
return arena{
|
||||
c: c,
|
||||
size: size,
|
||||
base: c.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
c *Conn
|
||||
base uint32
|
||||
next uint32
|
||||
size uint32
|
||||
ptrs []uint32
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.c.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
}
|
||||
|
||||
func (a *arena) new(size uint32) uint32 {
|
||||
if a.next+size <= a.size {
|
||||
ptr := a.base + a.next
|
||||
a.next += size
|
||||
return ptr
|
||||
}
|
||||
ptr := a.c.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint32(len(s) + 1))
|
||||
a.c.mem.writeString(ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
135
conn_test.go
135
conn_test.go
@@ -2,6 +2,7 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"testing"
|
||||
@@ -13,13 +14,15 @@ func TestConn_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_Close_BUSY(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare("BEGIN")
|
||||
stmt, _, err := db.Prepare(`BEGIN`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -41,14 +44,92 @@ func TestConn_Close_BUSY(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Empty(t *testing.T) {
|
||||
func TestConn_SetInterrupt(t *testing.T) {
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare("")
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
// Interrupt doesn't interrupt this.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
WITH RECURSIVE
|
||||
fibonacci (curr, next)
|
||||
AS (
|
||||
SELECT 0, 1
|
||||
UNION ALL
|
||||
SELECT next, curr + next FROM fibonacci
|
||||
LIMIT 1e6
|
||||
)
|
||||
SELECT min(curr) FROM fibonacci
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
cancel()
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
var serr *Error
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Interrupting sticks.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
// Interrupting can be cleared.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(``)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -60,6 +141,8 @@ func TestConn_Prepare_Empty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -68,7 +151,7 @@ func TestConn_Prepare_Invalid(t *testing.T) {
|
||||
|
||||
var serr *Error
|
||||
|
||||
_, _, err = db.Prepare("SELECT")
|
||||
_, _, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
@@ -82,7 +165,7 @@ func TestConn_Prepare_Invalid(t *testing.T) {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
_, _, err = db.Prepare("SELECT * FRM sqlite_schema")
|
||||
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
@@ -101,6 +184,8 @@ func TestConn_Prepare_Invalid(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -112,7 +197,41 @@ func TestConn_new(t *testing.T) {
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_newArena(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
arena := db.newArena(16)
|
||||
defer arena.reset()
|
||||
|
||||
const title = "Lorem ipsum"
|
||||
|
||||
ptr := arena.string(title)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
const body = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
|
||||
ptr = arena.string(body)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
|
||||
t.Errorf("got %q, want %q", got, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_newBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -137,6 +256,8 @@ func TestConn_newBytes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_newString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -161,6 +282,8 @@ func TestConn_newString(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_getString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -200,6 +323,8 @@ func TestConn_getString(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_free(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
24
const.go
24
const.go
@@ -9,11 +9,15 @@ const (
|
||||
|
||||
_UTF8 = 1
|
||||
|
||||
_MAX_STRING = 512 // Used for short strings: names, error messages…
|
||||
_MAX_PATHNAME = 512
|
||||
|
||||
ptrlen = 4
|
||||
)
|
||||
|
||||
// ErrorCode is a result code that [Error.Code] might return.
|
||||
//
|
||||
// https://www.sqlite.org/rescode.html
|
||||
type ErrorCode uint8
|
||||
|
||||
const (
|
||||
@@ -47,6 +51,9 @@ const (
|
||||
WARNING ErrorCode = 28 /* Warnings from sqlite3_log() */
|
||||
)
|
||||
|
||||
// ExtendedErrorCode is a result code that [Error.ExtendedCode] might return.
|
||||
//
|
||||
// https://www.sqlite.org/rescode.html
|
||||
type (
|
||||
ExtendedErrorCode uint16
|
||||
xErrorCode = ExtendedErrorCode
|
||||
@@ -128,6 +135,9 @@ const (
|
||||
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
|
||||
)
|
||||
|
||||
// OpenFlag is a flag for a file open operation.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
|
||||
type OpenFlag uint32
|
||||
|
||||
const (
|
||||
@@ -155,14 +165,17 @@ const (
|
||||
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
|
||||
)
|
||||
|
||||
type AccessFlag uint32
|
||||
type _AccessFlag uint32
|
||||
|
||||
const (
|
||||
ACCESS_EXISTS AccessFlag = 0
|
||||
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
|
||||
ACCESS_READ AccessFlag = 2 /* Unused */
|
||||
_ACCESS_EXISTS _AccessFlag = 0
|
||||
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
|
||||
_ACCESS_READ _AccessFlag = 2 /* Unused */
|
||||
)
|
||||
|
||||
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_prepare_normalize.html
|
||||
type PrepareFlag uint32
|
||||
|
||||
const (
|
||||
@@ -171,6 +184,9 @@ const (
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
)
|
||||
|
||||
// Datatype is a fundamental datatype of SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_blob.html
|
||||
type Datatype uint32
|
||||
|
||||
const (
|
||||
|
||||
284
driver/driver.go
Normal file
284
driver/driver.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Package driver provides a database/sql driver for SQLite.
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", sqlite{})
|
||||
}
|
||||
|
||||
type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If the database is not in WAL mode,
|
||||
// use normal locking mode.
|
||||
journal, err := pragma(c, "journal_mode")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if journal != "wal" {
|
||||
pragma(c, "locking_mode=normal")
|
||||
}
|
||||
return conn{c}, nil
|
||||
}
|
||||
|
||||
type conn struct{ conn *sqlite3.Conn }
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.Validator = conn{}
|
||||
_ driver.ExecerContext = conn{}
|
||||
// _ driver.ConnBeginTx = conn{}
|
||||
// _ driver.SessionResetter = conn{}
|
||||
)
|
||||
|
||||
func (c conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c conn) IsValid() bool {
|
||||
// Pool only normal locking mode connections.
|
||||
mode, _ := pragma(c.conn, "locking_mode")
|
||||
return mode == "normal"
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
err := c.conn.Exec(`BEGIN`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c conn) Commit() error {
|
||||
err := c.conn.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c conn) Rollback() error {
|
||||
return c.conn.Exec(`ROLLBACK`)
|
||||
}
|
||||
|
||||
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
s, tail, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tail != "" {
|
||||
// Check if the tail contains any SQL.
|
||||
s, _, err := c.conn.Prepare(tail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != nil {
|
||||
s.Close()
|
||||
return nil, tailErr
|
||||
}
|
||||
}
|
||||
return stmt{s, c.conn}, nil
|
||||
}
|
||||
|
||||
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if len(args) != 0 {
|
||||
// Slow path.
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
ch := c.conn.SetInterrupt(ctx.Done())
|
||||
defer c.conn.SetInterrupt(ch)
|
||||
|
||||
err := c.conn.Exec(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
int64(c.conn.LastInsertRowID()),
|
||||
int64(c.conn.Changes()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pragma(c *sqlite3.Conn, pragma string) (string, error) {
|
||||
stmt, _, err := c.Prepare(`PRAGMA ` + pragma)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0), nil
|
||||
}
|
||||
return "", stmt.Err()
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
)
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
return s.stmt.BindCount()
|
||||
}
|
||||
|
||||
// Deprecated: use ExecContext instead.
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
// Deprecated: use QueryContext instead.
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
int64(s.conn.LastInsertRowID()),
|
||||
int64(s.conn.Changes()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ids [3]int
|
||||
for _, arg := range args {
|
||||
ids := ids[:0]
|
||||
if arg.Name == "" {
|
||||
ids = append(ids, arg.Ordinal)
|
||||
} else {
|
||||
for _, prefix := range []string{":", "@", "$"} {
|
||||
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
switch a := arg.Value.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(id, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(id, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(id, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(id, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(id, a)
|
||||
case time.Time:
|
||||
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
|
||||
case nil:
|
||||
err = s.stmt.BindNull(id)
|
||||
default:
|
||||
panic(assertErr)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return rows{ctx, s.stmt, s.conn}, nil
|
||||
}
|
||||
|
||||
type result struct{ lastInsertId, rowsAffected int64 }
|
||||
|
||||
func (r result) LastInsertId() (int64, error) {
|
||||
return r.lastInsertId, nil
|
||||
}
|
||||
|
||||
func (r result) RowsAffected() (int64, error) {
|
||||
return r.rowsAffected, nil
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
ctx context.Context
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
func (r rows) Close() error {
|
||||
return r.stmt.Reset()
|
||||
}
|
||||
|
||||
func (r rows) Columns() []string {
|
||||
count := r.stmt.ColumnCount()
|
||||
columns := make([]string, count)
|
||||
for i := range columns {
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (r rows) Next(dest []driver.Value) error {
|
||||
ch := r.conn.SetInterrupt(r.ctx.Done())
|
||||
defer r.conn.SetInterrupt(ch)
|
||||
|
||||
if !r.stmt.Step() {
|
||||
if err := r.stmt.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch r.stmt.ColumnType(i) {
|
||||
case sqlite3.INTEGER:
|
||||
dest[i] = r.stmt.ColumnInt64(i)
|
||||
case sqlite3.FLOAT:
|
||||
dest[i] = r.stmt.ColumnFloat(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = maybeDate(r.stmt.ColumnText(i))
|
||||
case sqlite3.BLOB:
|
||||
buf, _ := dest[i].([]byte)
|
||||
dest[i] = r.stmt.ColumnBlob(i, buf)
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
} else {
|
||||
dest[i] = nil
|
||||
}
|
||||
default:
|
||||
panic(assertErr)
|
||||
}
|
||||
}
|
||||
|
||||
return r.stmt.Err()
|
||||
}
|
||||
10
driver/error.go
Normal file
10
driver/error.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package driver
|
||||
|
||||
type errorString string
|
||||
|
||||
func (e errorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
assertErr = errorString("sqlite3: assertion failed")
|
||||
tailErr = errorString("sqlite3: multiple statements")
|
||||
)
|
||||
149
driver/example_test.go
Normal file
149
driver/example_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package driver_test
|
||||
|
||||
// Adapted from: https://go.dev/doc/tutorial/database-access
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
type Album struct {
|
||||
ID int64
|
||||
Title string
|
||||
Artist string
|
||||
Price float32
|
||||
}
|
||||
|
||||
func Example() {
|
||||
// Get a database handle.
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", "./recordings.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
defer os.Remove("./recordings.db")
|
||||
|
||||
err = createAlbumsTable()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
albums, err := albumsByArtist("John Coltrane")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Albums found: %v\n", albums)
|
||||
|
||||
// Hard-code ID 2 here to test the query.
|
||||
alb, err := albumByID(2)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Album found: %v\n", alb)
|
||||
|
||||
albID, err := addAlbum(Album{
|
||||
Title: "The Modern Sound of Betty Carter",
|
||||
Artist: "Betty Carter",
|
||||
Price: 49.99,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("ID of added album: %v\n", albID)
|
||||
|
||||
// Output:
|
||||
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
|
||||
// Album found: {2 Giant Steps John Coltrane 63.99}
|
||||
// ID of added album: 5
|
||||
}
|
||||
|
||||
func createAlbumsTable() error {
|
||||
_, err := db.Exec(`
|
||||
DROP TABLE IF EXISTS album;
|
||||
CREATE TABLE album (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title VARCHAR(128) NOT NULL,
|
||||
artist VARCHAR(255) NOT NULL,
|
||||
price DECIMAL(5,2) NOT NULL
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO album
|
||||
(title, artist, price)
|
||||
VALUES
|
||||
('Blue Train', 'John Coltrane', 56.99),
|
||||
('Giant Steps', 'John Coltrane', 63.99),
|
||||
('Jeru', 'Gerry Mulligan', 17.99),
|
||||
('Sarah Vaughan', 'Sarah Vaughan', 34.98)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// albumsByArtist queries for albums that have the specified artist name.
|
||||
func albumsByArtist(name string) ([]Album, error) {
|
||||
// An albums slice to hold data from returned rows.
|
||||
var albums []Album
|
||||
|
||||
rows, err := db.Query("SELECT * FROM album WHERE artist = ?", name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
// Loop through rows, using Scan to assign column data to struct fields.
|
||||
for rows.Next() {
|
||||
var alb Album
|
||||
if err := rows.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
|
||||
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
|
||||
}
|
||||
albums = append(albums, alb)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
|
||||
}
|
||||
return albums, nil
|
||||
}
|
||||
|
||||
// albumByID queries for the album with the specified ID.
|
||||
func albumByID(id int64) (Album, error) {
|
||||
// An album to hold data from the returned row.
|
||||
var alb Album
|
||||
|
||||
row := db.QueryRow("SELECT * FROM album WHERE id = ?", id)
|
||||
if err := row.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return alb, fmt.Errorf("albumsById %d: no such album", id)
|
||||
}
|
||||
return alb, fmt.Errorf("albumsById %d: %w", id, err)
|
||||
}
|
||||
return alb, nil
|
||||
}
|
||||
|
||||
// addAlbum adds the specified album to the database,
|
||||
// returning the album ID of the new entry
|
||||
func addAlbum(alb Album) (int64, error) {
|
||||
result, err := db.Exec("INSERT INTO album (title, artist, price) VALUES (?, ?, ?)", alb.Title, alb.Artist, alb.Price)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("addAlbum: %w", err)
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("addAlbum: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
19
driver/time.go
Normal file
19
driver/time.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
|
||||
// if it roundtrips back to the same string.
|
||||
// This way times can be persisted to, and recovered from, the database,
|
||||
// but if a string is needed, [database.sql] will recover the same string.
|
||||
// TODO: optimize and fuzz test.
|
||||
func maybeDate(text string) driver.Value {
|
||||
date, err := time.Parse(time.RFC3339Nano, text)
|
||||
if err == nil && date.Format(time.RFC3339Nano) == text {
|
||||
return date
|
||||
}
|
||||
return text
|
||||
}
|
||||
14
driver/util.go
Normal file
14
driver/util.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package driver
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
func namedValues(args []driver.Value) []driver.NamedValue {
|
||||
named := make([]driver.NamedValue, len(args))
|
||||
for i, v := range args {
|
||||
named[i] = driver.NamedValue{
|
||||
Ordinal: i + 1,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return named
|
||||
}
|
||||
18
driver/util_test.go
Normal file
18
driver/util_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_namedValues(t *testing.T) {
|
||||
want := []driver.NamedValue{
|
||||
{Ordinal: 1, Value: true},
|
||||
{Ordinal: 2, Value: false},
|
||||
}
|
||||
got := namedValues([]driver.Value{true, false})
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -28,15 +28,23 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \
|
||||
-Wl,--export=sqlite3_step \
|
||||
-Wl,--export=sqlite3_exec \
|
||||
-Wl,--export=sqlite3_clear_bindings \
|
||||
-Wl,--export=sqlite3_bind_parameter_count \
|
||||
-Wl,--export=sqlite3_bind_parameter_index \
|
||||
-Wl,--export=sqlite3_bind_parameter_name \
|
||||
-Wl,--export=sqlite3_bind_null \
|
||||
-Wl,--export=sqlite3_bind_int64 \
|
||||
-Wl,--export=sqlite3_bind_double \
|
||||
-Wl,--export=sqlite3_bind_text64 \
|
||||
-Wl,--export=sqlite3_bind_blob64 \
|
||||
-Wl,--export=sqlite3_bind_zeroblob64 \
|
||||
-Wl,--export=sqlite3_bind_null \
|
||||
-Wl,--export=sqlite3_column_count \
|
||||
-Wl,--export=sqlite3_column_name \
|
||||
-Wl,--export=sqlite3_column_type \
|
||||
-Wl,--export=sqlite3_column_int64 \
|
||||
-Wl,--export=sqlite3_column_double \
|
||||
-Wl,--export=sqlite3_column_text \
|
||||
-Wl,--export=sqlite3_column_blob \
|
||||
-Wl,--export=sqlite3_column_bytes \
|
||||
-Wl,--export=sqlite3_column_type \
|
||||
-Wl,--export=sqlite3_last_insert_rowid \
|
||||
-Wl,--export=sqlite3_changes64 \
|
||||
-Wl,--export=sqlite3_interrupt \
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
// Package embed embeds SQLite into your application.
|
||||
//
|
||||
// You can obtain this build of SQLite from:
|
||||
// https://github.com/ncruces/go-sqlite3/tree/main/embed
|
||||
package embed
|
||||
|
||||
import (
|
||||
|
||||
Binary file not shown.
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
const memory = ":memory:"
|
||||
|
||||
func Example() {
|
||||
db, err := sqlite3.Open(memory)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -45,4 +47,9 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 2 whatever
|
||||
}
|
||||
8
mem.go
8
mem.go
@@ -99,9 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (m memory) writeBytes(ptr uint32, b []byte) {
|
||||
buf := m.view(ptr, uint32(len(b)))
|
||||
copy(buf, b)
|
||||
}
|
||||
|
||||
func (m memory) writeString(ptr uint32, s string) {
|
||||
siz := uint32(len(s) + 1)
|
||||
buf := m.view(ptr, siz)
|
||||
buf := m.view(ptr, uint32(len(s)+1))
|
||||
buf[len(s)] = 0
|
||||
copy(buf, s)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <time.h>
|
||||
|
||||
@@ -33,6 +34,7 @@ int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_truncate(sqlite3_file *, sqlite3_int64 size);
|
||||
int go_sync(sqlite3_file *, int flags);
|
||||
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
|
||||
int go_file_control(sqlite3_file *pFile, int op, void *pArg);
|
||||
|
||||
int go_lock(sqlite3_file *pFile, int eLock);
|
||||
int go_unlock(sqlite3_file *pFile, int eLock);
|
||||
@@ -94,7 +96,7 @@ int sqlite3_os_init() {
|
||||
.xCurrentTime = go_current_time,
|
||||
.xCurrentTimeInt64 = go_current_time_64,
|
||||
};
|
||||
return sqlite3_vfs_register(&go_vfs, /*default=*/1);
|
||||
return sqlite3_vfs_register(&go_vfs, /*default=*/true);
|
||||
}
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
@@ -5,6 +5,9 @@
|
||||
#define SQLITE_OS_OTHER 1
|
||||
#define SQLITE_BYTEORDER 1234
|
||||
|
||||
#define HAVE_STDINT_H 1
|
||||
#define HAVE_INTTYPES_H 1
|
||||
|
||||
#define HAVE_ISNAN 1
|
||||
#define HAVE_USLEEP 1
|
||||
#define HAVE_LOCALTIME_S 1
|
||||
@@ -25,12 +28,18 @@
|
||||
#define SQLITE_OMIT_AUTOINIT
|
||||
#define SQLITE_USE_ALLOCA
|
||||
|
||||
// Recommended Extensions
|
||||
|
||||
// #define SQLITE_ENABLE_MATH_FUNCTIONS 1
|
||||
// #define SQLITE_ENABLE_FTS3 1
|
||||
// #define SQLITE_ENABLE_FTS3_PARENTHESIS 1
|
||||
// #define SQLITE_ENABLE_FTS4 1
|
||||
// #define SQLITE_ENABLE_FTS5 1
|
||||
// #define SQLITE_ENABLE_RTREE 1
|
||||
// #define SQLITE_ENABLE_GEOPOLY 1
|
||||
|
||||
// Need this to access WAL databases without the use of shared memory.
|
||||
#define SQLITE_DEFAULT_LOCKING_MODE 1
|
||||
// Go uses UTF-8 everywhere.
|
||||
#define SQLITE_OMIT_UTF16
|
||||
// Remove some testing code.
|
||||
#define SQLITE_UNTESTABLE
|
||||
|
||||
// Implemented in Go.
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime);
|
||||
96
stmt.go
96
stmt.go
@@ -17,13 +17,13 @@ type Stmt struct {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/finalize.html
|
||||
func (s *Stmt) Close() error {
|
||||
if s == nil {
|
||||
if s == nil || s.handle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.handle = 0
|
||||
@@ -36,7 +36,7 @@ func (s *Stmt) Close() error {
|
||||
func (s *Stmt) Reset() error {
|
||||
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
s.err = nil
|
||||
return s.c.error(r[0])
|
||||
@@ -48,7 +48,7 @@ func (s *Stmt) Reset() error {
|
||||
func (s *Stmt) ClearBindings() error {
|
||||
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
@@ -65,8 +65,7 @@ func (s *Stmt) ClearBindings() error {
|
||||
func (s *Stmt) Step() bool {
|
||||
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return false
|
||||
panic(err)
|
||||
}
|
||||
if r[0] == _ROW {
|
||||
return true
|
||||
@@ -95,6 +94,51 @@ func (s *Stmt) Exec() error {
|
||||
return s.Reset()
|
||||
}
|
||||
|
||||
// BindCount returns the number of SQL parameters in the prepared statement.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_parameter_count.html
|
||||
func (s *Stmt) BindCount() int {
|
||||
r, err := s.c.api.bindCount.Call(s.c.ctx,
|
||||
uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
}
|
||||
|
||||
// BindIndex returns the index of a parameter in the prepared statement
|
||||
// given its name.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_parameter_index.html
|
||||
func (s *Stmt) BindIndex(name string) int {
|
||||
defer s.c.arena.reset()
|
||||
namePtr := s.c.arena.string(name)
|
||||
r, err := s.c.api.bindIndex.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(namePtr))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
}
|
||||
|
||||
// BindName returns the name of a parameter in the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_parameter_name.html
|
||||
func (s *Stmt) BindName(param int) string {
|
||||
r, err := s.c.api.bindName.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(param))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
return s.c.mem.readString(ptr, _MAX_STRING)
|
||||
}
|
||||
|
||||
// BindBool binds a bool to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
@@ -124,7 +168,7 @@ func (s *Stmt) BindInt64(param int, value int64) error {
|
||||
r, err := s.c.api.bindInteger.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(param), uint64(value))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
@@ -137,7 +181,7 @@ func (s *Stmt) BindFloat(param int, value float64) error {
|
||||
r, err := s.c.api.bindFloat.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(param), math.Float64bits(value))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
@@ -153,7 +197,7 @@ func (s *Stmt) BindText(param int, value string) error {
|
||||
uint64(ptr), uint64(len(value)),
|
||||
s.c.api.destructor, _UTF8)
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
@@ -170,7 +214,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
|
||||
uint64(ptr), uint64(len(value)),
|
||||
s.c.api.destructor)
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
@@ -183,11 +227,41 @@ func (s *Stmt) BindNull(param int) error {
|
||||
r, err := s.c.api.bindNull.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(param))
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_count.html
|
||||
func (s *Stmt) ColumnCount() int {
|
||||
r, err := s.c.api.columnCount.Call(s.c.ctx,
|
||||
uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
}
|
||||
|
||||
// ColumnName returns the name of the result column.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_name.html
|
||||
func (s *Stmt) ColumnName(col int) string {
|
||||
r, err := s.c.api.columnName.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
return s.c.mem.readString(ptr, _MAX_STRING)
|
||||
}
|
||||
|
||||
// ColumnType returns the initial [Datatype] of the result column.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
|
||||
39
stmt_test.go
39
stmt_test.go
@@ -6,6 +6,8 @@ import (
|
||||
)
|
||||
|
||||
func TestStmt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -23,6 +25,10 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if got := stmt.BindCount(); got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
err = stmt.BindBool(1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -133,6 +139,7 @@ func TestStmt(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
@@ -359,3 +366,35 @@ func TestStmt_Close(t *testing.T) {
|
||||
var stmt *Stmt
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
func TestStmt_BindName(t *testing.T) {
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
want := []string{"", "", "", "", "?5", ":AAA", "@AAA", "$AAA"}
|
||||
stmt, _, err := db.Prepare(`SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if got := stmt.BindCount(); got != len(want) {
|
||||
t.Errorf("got %d, want %d", got, len(want))
|
||||
}
|
||||
|
||||
for i, name := range want {
|
||||
id := i + 1
|
||||
if got := stmt.BindName(id); got != name {
|
||||
t.Errorf("got %q, want %q", got, name)
|
||||
}
|
||||
if name == "" {
|
||||
id = 0
|
||||
}
|
||||
if got := stmt.BindIndex(name); got != id {
|
||||
t.Errorf("got %d, want %d", got, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -14,13 +13,7 @@ func TestDB_memory(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDB_file(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "sqlite3-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
testDB(t, filepath.Join(dir, "test.db"))
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
@@ -44,6 +37,7 @@ func testDB(t *testing.T, name string) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 1, 2}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -13,18 +14,53 @@ import (
|
||||
)
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip()
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
testParallel(t, name, 100)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestMultiProcess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "sqlite3-")
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbname", name)
|
||||
|
||||
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
|
||||
out, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var buf [3]byte
|
||||
// Wait for child to start.
|
||||
if _, err := io.ReadFull(out, buf[:]); err != nil || string(buf[:]) != "===" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
if err := cmd.Wait(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestChildProcess(t *testing.T) {
|
||||
name := os.Getenv("TestMultiProcess_dbname")
|
||||
if name == "" || testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
}
|
||||
|
||||
func testParallel(t *testing.T, name string, n int) {
|
||||
writer := func() error {
|
||||
db, err := sqlite3.Open(filepath.Join(dir, "test.db"))
|
||||
db, err := sqlite3.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -32,7 +68,7 @@ func TestParallel(t *testing.T) {
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA locking_mode = NORMAL;
|
||||
PRAGMA busy_timeout = 1000;
|
||||
PRAGMA busy_timeout = 10000;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -52,7 +88,7 @@ func TestParallel(t *testing.T) {
|
||||
}
|
||||
|
||||
reader := func() error {
|
||||
db, err := sqlite3.Open(filepath.Join(dir, "test.db"))
|
||||
db, err := sqlite3.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -60,7 +96,7 @@ func TestParallel(t *testing.T) {
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA locking_mode = NORMAL;
|
||||
PRAGMA busy_timeout = 1000;
|
||||
PRAGMA busy_timeout = 10000;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -70,6 +106,7 @@ func TestParallel(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
for stmt.Step() {
|
||||
@@ -90,14 +127,14 @@ func TestParallel(t *testing.T) {
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
err = writer()
|
||||
err := writer()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var group errgroup.Group
|
||||
group.SetLimit(4)
|
||||
for i := 0; i < 32; i++ {
|
||||
for i := 0; i < n; i++ {
|
||||
if i&7 != 7 {
|
||||
group.Go(reader)
|
||||
} else {
|
||||
@@ -109,3 +146,41 @@ func TestParallel(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testIntegrity(t *testing.T, name string) {
|
||||
db, err := sqlite3.Open(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
test := `PRAGMA integrity_check`
|
||||
if testing.Short() {
|
||||
test = `PRAGMA quick_check`
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(test)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
if row := stmt.ColumnText(0); row != "ok" {
|
||||
t.Error(row)
|
||||
}
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
35
vfs.go
35
vfs.go
@@ -18,12 +18,12 @@ import (
|
||||
"github.com/tetratelabs/wazero/sys"
|
||||
)
|
||||
|
||||
func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) {
|
||||
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
|
||||
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
|
||||
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
|
||||
_, err = wasi.Instantiate(ctx)
|
||||
_, err := wasi.Instantiate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
env := r.NewHostModuleBuilder("env")
|
||||
@@ -45,8 +45,11 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) {
|
||||
env.NewFunctionBuilder().WithFunc(vfsLock).Export("go_lock")
|
||||
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("go_unlock")
|
||||
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("go_check_reserved_lock")
|
||||
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("go_file_control")
|
||||
_, err = env.Instantiate(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
|
||||
@@ -112,11 +115,11 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull
|
||||
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
|
||||
// This might be buggy on Windows (the Windows VFS doesn't try).
|
||||
|
||||
siz := uint32(len(abs) + 1)
|
||||
if siz > nFull {
|
||||
size := uint32(len(abs) + 1)
|
||||
if size > nFull {
|
||||
return uint32(CANTOPEN_FULLPATH)
|
||||
}
|
||||
mem := memory{mod}.view(zFull, siz)
|
||||
mem := memory{mod}.view(zFull, size)
|
||||
|
||||
mem[len(abs)] = 0
|
||||
copy(mem, abs)
|
||||
@@ -145,7 +148,7 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32)
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) uint32 {
|
||||
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
|
||||
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
|
||||
// (as the Unix VFS does).
|
||||
|
||||
@@ -154,7 +157,7 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags Ac
|
||||
|
||||
var res uint32
|
||||
switch {
|
||||
case flags == ACCESS_EXISTS:
|
||||
case flags == _ACCESS_EXISTS:
|
||||
switch {
|
||||
case err == nil:
|
||||
res = 1
|
||||
@@ -166,7 +169,7 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags Ac
|
||||
|
||||
case err == nil:
|
||||
var want fs.FileMode = syscall.S_IRUSR
|
||||
if flags == ACCESS_READWRITE {
|
||||
if flags == _ACCESS_READWRITE {
|
||||
want |= syscall.S_IWUSR
|
||||
}
|
||||
if fi.IsDir() {
|
||||
@@ -304,3 +307,15 @@ func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint3
|
||||
memory{mod}.writeUint64(pSize, uint64(off))
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
|
||||
// SQLite calls vfsFileControl with these opcodes:
|
||||
// SQLITE_FCNTL_SIZE_HINT
|
||||
// SQLITE_FCNTL_PRAGMA
|
||||
// SQLITE_FCNTL_BUSYHANDLER
|
||||
// SQLITE_FCNTL_HAS_MOVED
|
||||
// SQLITE_FCNTL_SYNC
|
||||
// SQLITE_FCNTL_COMMIT_PHASETWO
|
||||
// SQLITE_FCNTL_PDB
|
||||
return uint32(NOTFOUND)
|
||||
}
|
||||
|
||||
83
vfs_lock.go
83
vfs_lock.go
@@ -64,7 +64,7 @@ type vfsFileLocker struct {
|
||||
}
|
||||
|
||||
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
|
||||
// SQLite never explicitly requests a pendig lock.
|
||||
// Argument check. SQLite never explicitly requests a pendig lock.
|
||||
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
@@ -72,12 +72,10 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
|
||||
ptr := vfsFilePtr{mod, pFile}
|
||||
cLock := ptr.Lock()
|
||||
|
||||
// If we already have an equal or more restrictive lock, do nothing.
|
||||
if cLock >= eLock {
|
||||
return _OK
|
||||
}
|
||||
|
||||
switch {
|
||||
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
|
||||
// Connection state check.
|
||||
panic(assertErr())
|
||||
case cLock == _NO_LOCK && eLock > _SHARED_LOCK:
|
||||
// We never move from unlocked to anything higher than a shared lock.
|
||||
panic(assertErr())
|
||||
@@ -86,31 +84,51 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
|
||||
panic(assertErr())
|
||||
}
|
||||
|
||||
// If we already have an equal or more restrictive lock, do nothing.
|
||||
if cLock >= eLock {
|
||||
return _OK
|
||||
}
|
||||
|
||||
fLock := ptr.Locker()
|
||||
fLock.Lock()
|
||||
defer fLock.Unlock()
|
||||
|
||||
// File state check.
|
||||
switch {
|
||||
case fLock.state < _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
|
||||
panic(assertErr())
|
||||
case fLock.state == _NO_LOCK && fLock.shared != 0:
|
||||
panic(assertErr())
|
||||
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
|
||||
panic(assertErr())
|
||||
case fLock.state != _NO_LOCK && fLock.shared <= 0:
|
||||
panic(assertErr())
|
||||
case fLock.state < cLock:
|
||||
panic(assertErr())
|
||||
}
|
||||
|
||||
// If some other connection has a lock that precludes the requested lock, return BUSY.
|
||||
if cLock != fLock.state && (eLock > _SHARED_LOCK || fLock.state >= _PENDING_LOCK) {
|
||||
return uint32(BUSY)
|
||||
}
|
||||
|
||||
// If a SHARED lock is requested, and some other connection has a SHARED or RESERVED lock,
|
||||
// then increment the reference count and return OK.
|
||||
if eLock == _SHARED_LOCK && (fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK) {
|
||||
if cLock != _NO_LOCK || fLock.shared <= 0 {
|
||||
panic(assertErr())
|
||||
}
|
||||
ptr.SetLock(_SHARED_LOCK)
|
||||
fLock.shared++
|
||||
return _OK
|
||||
}
|
||||
|
||||
// If control gets to this point, then actually go ahead and make
|
||||
// operating system calls for the specified lock.
|
||||
switch eLock {
|
||||
case _SHARED_LOCK:
|
||||
if fLock.state != _NO_LOCK || fLock.shared != 0 {
|
||||
// Test the PENDING lock before acquiring a new SHARED lock.
|
||||
if locked, _ := fLock.CheckPending(); locked {
|
||||
return uint32(BUSY)
|
||||
}
|
||||
|
||||
// If some other connection has a SHARED or RESERVED lock,
|
||||
// increment the reference count and return OK.
|
||||
if fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK {
|
||||
ptr.SetLock(_SHARED_LOCK)
|
||||
fLock.shared++
|
||||
return _OK
|
||||
}
|
||||
|
||||
// Must be unlocked to get SHARED.
|
||||
if fLock.state != _NO_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
if rc := fLock.GetShared(); rc != _OK {
|
||||
@@ -122,7 +140,8 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
|
||||
return _OK
|
||||
|
||||
case _RESERVED_LOCK:
|
||||
if fLock.state != _SHARED_LOCK || fLock.shared <= 0 {
|
||||
// Must be SHARED to get RESERVED.
|
||||
if fLock.state != _SHARED_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
if rc := fLock.GetReserved(); rc != _OK {
|
||||
@@ -133,7 +152,8 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
|
||||
return _OK
|
||||
|
||||
case _EXCLUSIVE_LOCK:
|
||||
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK || fLock.shared <= 0 {
|
||||
// Must be SHARED, PENDING or RESERVED to get EXCLUSIVE.
|
||||
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
|
||||
@@ -164,6 +184,7 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
|
||||
}
|
||||
|
||||
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
|
||||
// Argument check.
|
||||
if eLock != _NO_LOCK && eLock != _SHARED_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
@@ -171,6 +192,11 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
|
||||
ptr := vfsFilePtr{mod, pFile}
|
||||
cLock := ptr.Lock()
|
||||
|
||||
// Connection state check.
|
||||
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
|
||||
// If we don't have a more restrictive lock, do nothing.
|
||||
if cLock <= eLock {
|
||||
return _OK
|
||||
@@ -180,10 +206,20 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
|
||||
fLock.Lock()
|
||||
defer fLock.Unlock()
|
||||
|
||||
if fLock.shared <= 0 {
|
||||
// File state check.
|
||||
switch {
|
||||
case fLock.state <= _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
|
||||
panic(assertErr())
|
||||
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
|
||||
panic(assertErr())
|
||||
case fLock.shared <= 0:
|
||||
panic(assertErr())
|
||||
case fLock.state < cLock:
|
||||
panic(assertErr())
|
||||
}
|
||||
|
||||
if cLock > _SHARED_LOCK {
|
||||
// The connection must own the lock to release it.
|
||||
if cLock != fLock.state {
|
||||
panic(assertErr())
|
||||
}
|
||||
@@ -197,6 +233,7 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, make sure we're dropping all locks.
|
||||
if eLock != _NO_LOCK {
|
||||
panic(assertErr())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func Test_vfsLock(t *testing.T) {
|
||||
// Other OSes lack open file descriptors locks.
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin", "solaris", "windows":
|
||||
case "linux", "darwin", "illumos", "windows":
|
||||
//
|
||||
default:
|
||||
t.Skip()
|
||||
|
||||
12
vfs_test.go
12
vfs_test.go
@@ -163,16 +163,10 @@ func Test_vfsDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_vfsAccess(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "sqlite3-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
mem := newMemory(128 + _MAX_PATHNAME)
|
||||
mem.writeString(8, dir)
|
||||
mem.writeString(8, t.TempDir())
|
||||
|
||||
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_EXISTS, 4)
|
||||
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
@@ -180,7 +174,7 @@ func Test_vfsAccess(t *testing.T) {
|
||||
t.Error("directory did not exist")
|
||||
}
|
||||
|
||||
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_READWRITE, 4)
|
||||
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
|
||||
39
vfs_unix.go
39
vfs_unix.go
@@ -13,19 +13,8 @@ func deleteOnClose(f *os.File) {
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) GetShared() xErrorCode {
|
||||
// A PENDING lock is needed before acquiring a SHARED lock.
|
||||
if rc := l.readLock(_PENDING_BYTE, 1); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
|
||||
// Acquire the SHARED lock.
|
||||
rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
|
||||
// Drop the temporary PENDING lock.
|
||||
if rc2 := l.unlock(_PENDING_BYTE, 1); rc == _OK {
|
||||
return rc2
|
||||
}
|
||||
return rc
|
||||
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) GetReserved() xErrorCode {
|
||||
@@ -68,6 +57,11 @@ func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
|
||||
return l.checkLock(_RESERVED_BYTE, 1)
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
|
||||
// Test the PENDING lock.
|
||||
return l.checkLock(_PENDING_BYTE, 1)
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) unlock(start, len int64) xErrorCode {
|
||||
err := l.fcntlSetLock(&syscall.Flock_t{
|
||||
Type: syscall.F_UNLCK,
|
||||
@@ -85,7 +79,7 @@ func (l *vfsFileLocker) readLock(start, len int64) xErrorCode {
|
||||
Type: syscall.F_RDLCK,
|
||||
Start: start,
|
||||
Len: len,
|
||||
}), IOERR_LOCK)
|
||||
}), IOERR_RDLOCK)
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) writeLock(start, len int64) xErrorCode {
|
||||
@@ -117,7 +111,7 @@ func (l *vfsFileLocker) fcntlGetLock(lock *syscall.Flock_t) error {
|
||||
case "darwin":
|
||||
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
|
||||
F_GETLK = 92 // F_OFD_GETLK
|
||||
case "solaris":
|
||||
case "illumos":
|
||||
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
|
||||
F_GETLK = 47 // F_OFD_GETLK
|
||||
}
|
||||
@@ -133,7 +127,7 @@ func (l *vfsFileLocker) fcntlSetLock(lock *syscall.Flock_t) error {
|
||||
case "darwin":
|
||||
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
|
||||
F_SETLK = 90 // F_OFD_SETLK
|
||||
case "solaris":
|
||||
case "illumos":
|
||||
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
|
||||
F_SETLK = 48 // F_OFD_SETLK
|
||||
}
|
||||
@@ -146,13 +140,14 @@ func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
|
||||
}
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
switch errno {
|
||||
case syscall.EACCES:
|
||||
case syscall.EAGAIN:
|
||||
case syscall.EBUSY:
|
||||
case syscall.EINTR:
|
||||
case syscall.ENOLCK:
|
||||
case syscall.EDEADLK:
|
||||
case syscall.ETIMEDOUT:
|
||||
case
|
||||
syscall.EACCES,
|
||||
syscall.EAGAIN,
|
||||
syscall.EBUSY,
|
||||
syscall.EINTR,
|
||||
syscall.ENOLCK,
|
||||
syscall.EDEADLK,
|
||||
syscall.ETIMEDOUT:
|
||||
return xErrorCode(BUSY)
|
||||
case syscall.EPERM:
|
||||
return xErrorCode(PERM)
|
||||
|
||||
@@ -10,19 +10,8 @@ import (
|
||||
func deleteOnClose(f *os.File) {}
|
||||
|
||||
func (l *vfsFileLocker) GetShared() xErrorCode {
|
||||
// A PENDING lock is needed before acquiring a SHARED lock.
|
||||
if rc := l.readLock(_PENDING_BYTE, 1); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
|
||||
// Acquire the SHARED lock.
|
||||
rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
|
||||
// Drop the temporary PENDING lock.
|
||||
if rc2 := l.unlock(_PENDING_BYTE, 1); rc == _OK {
|
||||
return rc2
|
||||
}
|
||||
return rc
|
||||
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) GetReserved() xErrorCode {
|
||||
@@ -83,6 +72,15 @@ func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
|
||||
return rc != _OK, _OK
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
|
||||
// Test the PENDING lock.
|
||||
rc := l.readLock(_PENDING_BYTE, 1)
|
||||
if rc == _OK {
|
||||
l.unlock(_PENDING_BYTE, 1)
|
||||
}
|
||||
return rc != _OK, _OK
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) unlock(start, len uint32) xErrorCode {
|
||||
err := windows.UnlockFileEx(windows.Handle(l.file.Fd()),
|
||||
0, len, 0, &windows.Overlapped{Offset: start})
|
||||
@@ -96,7 +94,7 @@ func (l *vfsFileLocker) readLock(start, len uint32) xErrorCode {
|
||||
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
|
||||
windows.LOCKFILE_FAIL_IMMEDIATELY,
|
||||
0, len, 0, &windows.Overlapped{Offset: start}),
|
||||
IOERR_LOCK)
|
||||
IOERR_RDLOCK)
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) writeLock(start, len uint32) xErrorCode {
|
||||
|
||||
Reference in New Issue
Block a user