Compare commits

...

23 Commits

Author SHA1 Message Date
Nuno Cruces
ec5bd236f8 Documentation. 2023-02-18 03:46:52 +00:00
Nuno Cruces
a51cdb04e6 Exec fast path. 2023-02-18 02:57:47 +00:00
Nuno Cruces
f50d5df3d0 Context cancellation. 2023-02-18 02:16:11 +00:00
Nuno Cruces
4ac2ccf473 Named parameters. 2023-02-18 00:47:56 +00:00
Nuno Cruces
5f7a72a553 Connection reuse. 2023-02-17 16:36:47 +00:00
Nuno Cruces
643b004727 Reuse byte slices. 2023-02-17 12:30:07 +00:00
Nuno Cruces
72e0415184 Time handling. 2023-02-17 10:40:43 +00:00
Nuno Cruces
28cb558d10 Minimal database/sql driver. 2023-02-17 02:21:07 +00:00
Nuno Cruces
23806b0db1 More tests. 2023-02-16 13:58:53 +00:00
Nuno Cruces
6a80499823 Panic consistently. 2023-02-16 13:52:05 +00:00
Nuno Cruces
110f36bdf9 Fix flakiness. 2023-02-16 13:37:29 +00:00
Nuno Cruces
f85426022d Test data races. 2023-02-15 16:24:34 +00:00
Nuno Cruces
78fd0cbee5 Towards database/sql. 2023-02-15 16:15:14 +00:00
Nuno Cruces
0d59065719 Lock errors. 2023-02-14 11:38:05 +00:00
Nuno Cruces
6110e2d6e2 Memory arenas. 2023-02-14 11:34:24 +00:00
Nuno Cruces
275b8c38a2 Documentation. 2023-02-14 11:33:41 +00:00
Nuno Cruces
fd1244c471 Support utf16 DBs. 2023-02-14 01:21:12 +00:00
Nuno Cruces
f11d294825 Check integrity. 2023-02-13 16:00:27 +00:00
Nuno Cruces
22b702fcda Synchronize IPC test. 2023-02-13 15:23:11 +00:00
Nuno Cruces
831817a737 Test IPC. 2023-02-13 15:01:36 +00:00
Nuno Cruces
7329d9f2fb Avoid writer starvation. 2023-02-13 13:53:32 +00:00
Nuno Cruces
3aad1d5d79 Towards xFileControl. 2023-02-13 13:52:52 +00:00
Nuno Cruces
f72c599d2d illumos OFD locks. 2023-02-13 13:51:35 +00:00
30 changed files with 1226 additions and 189 deletions

View File

@@ -28,7 +28,13 @@ jobs:
- name: Test
run: go test -v ./...
- if: matrix.os == 'ubuntu-latest'
name: Update coverage report
- name: Test data races
run: go test -v -race ./...
if: matrix.os == 'ubuntu-latest'
- name: Update coverage report
uses: ncruces/go-coverage-report@main
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'
continue-on-error: true

View File

@@ -6,13 +6,16 @@
⚠️ CAUTION ⚠️
This is still very much a WIP.\
DO NOT USE this with data you care about.
This is a WIP.\
DO NOT USE with data you care about.
Roadmap:
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] come up with a simple, nice API, enough for simple queries
- [x] file locking, compatible with SQLite on Windows/Unix
- [x] design a simple, nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on Windows/Unix
- [ ] shared memory, compatible with SQLite on Windows/Unix
- needed for improved WAL mode

25
api.go
View File

@@ -1,3 +1,4 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
@@ -44,18 +45,26 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindNull: getFun("sqlite3_bind_null"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
columnType: getFun("sqlite3_column_type"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
interrupt: getFun("sqlite3_interrupt"),
},
}
if err != nil {
@@ -80,16 +89,24 @@ type sqliteAPI struct {
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindNull api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
columnType api.Function
lastRowid api.Function
changes api.Function
interrupt api.Function
}

View File

@@ -24,13 +24,11 @@ type sqlite3Runtime struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
ctx context.Context
err error
}
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.ctx = ctx
s.once.Do(s.compileModule)
s.once.Do(func() { s.compileModule(ctx) })
if s.err != nil {
return nil, s.err
}
@@ -40,12 +38,9 @@ func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, err
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule() {
s.runtime = wazero.NewRuntime(s.ctx)
s.err = vfsInstantiate(s.ctx, s.runtime)
if s.err != nil {
return
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
bin := Binary
if bin == nil && Path != "" {
@@ -55,5 +50,5 @@ func (s *sqlite3Runtime) compileModule() {
}
}
s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin)
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

206
conn.go
View File

@@ -13,6 +13,11 @@ type Conn struct {
api sqliteAPI
mem memory
handle uint32
arena arena
pending *Stmt
waiter chan struct{}
done <-chan struct{}
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
@@ -39,15 +44,15 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
if err != nil {
return nil, err
}
c.arena = c.newArena(1024)
namePtr := c.newString(filename)
connPtr := c.new(ptrlen)
defer c.free(namePtr)
defer c.free(connPtr)
defer c.arena.reset()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
if err != nil {
return nil, err
panic(err)
}
c.handle = c.mem.readUint32(connPtr)
@@ -65,13 +70,15 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
//
// https://www.sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
if c == nil {
if c == nil || c.handle == 0 {
return nil
}
c.SetInterrupt(nil)
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
return err
panic(err)
}
if err := c.error(r[0]); err != nil {
@@ -82,17 +89,80 @@ func (c *Conn) Close() error {
return c.mem.mod.Close(c.ctx)
}
// SetInterrupt interrupts a long-running query when done is closed.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until done is reset by another call to SetInterrupt.
//
// Typically, done is provided by [context.Context.Done]:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx.Done())
// defer cancel()
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
old = c.done
c.done = done
if done == nil {
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-done: // Done was closed.
// This is safe to call from a goroutine
// because it doesn't touch the C stack.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
}
// Exec is a convenience function that allows an application to run
// multiple statements of SQL without having to use a lot of code.
//
// https://www.sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
sqlPtr := c.newString(sql)
defer c.free(sqlPtr)
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
return err
panic(err)
}
return c.error(r[0])
}
@@ -109,18 +179,16 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
sqlPtr := c.newString(sql)
stmtPtr := c.new(ptrlen)
tailPtr := c.new(ptrlen)
defer c.free(sqlPtr)
defer c.free(stmtPtr)
defer c.free(tailPtr)
defer c.arena.reset()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
return nil, "", err
panic(err)
}
stmt = &Stmt{c: c}
@@ -137,6 +205,31 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
// Changes returns the number of rows modified, inserted or deleted
// by the most recently completed INSERT, UPDATE or DELETE statement
// on the database connection.
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil
@@ -150,28 +243,26 @@ func (c *Conn) error(rc uint64, sql ...string) error {
var r []uint64
// sqlite3_errmsg is guaranteed to never change the value of the error code.
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), 512)
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
// sqlite3_error_offset is guaranteed to never change the value of the error code.
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), 512)
}
if err.msg == err.str {
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
@@ -186,13 +277,13 @@ func (c *Conn) free(ptr uint32) {
}
}
func (c *Conn) new(len uint32) uint32 {
r, err := c.api.malloc.Call(c.ctx, uint64(len))
func (c *Conn) new(size uint32) uint32 {
r, err := c.api.malloc.Call(c.ctx, uint64(size))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 && len != 0 {
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
@@ -202,19 +293,54 @@ func (c *Conn) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
siz := uint32(len(b))
ptr := c.new(siz)
buf := c.mem.view(ptr, siz)
copy(buf, b)
ptr := c.new(uint32(len(b)))
c.mem.writeBytes(ptr, b)
return ptr
}
func (c *Conn) newString(s string) uint32 {
siz := uint32(len(s) + 1)
ptr := c.new(siz)
buf := c.mem.view(ptr, siz)
buf[len(s)] = 0
copy(buf, s)
ptr := c.new(uint32(len(s) + 1))
c.mem.writeString(ptr, s)
return ptr
}
func (c *Conn) newArena(size uint32) arena {
return arena{
c: c,
size: size,
base: c.new(size),
}
}
type arena struct {
c *Conn
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.c.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint32) uint32 {
if a.next+size <= a.size {
ptr := a.base + a.next
a.next += size
return ptr
}
ptr := a.c.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint32(len(s) + 1))
a.c.mem.writeString(ptr, s)
return ptr
}

View File

@@ -2,6 +2,7 @@ package sqlite3
import (
"bytes"
"context"
"errors"
"math"
"testing"
@@ -13,13 +14,15 @@ func TestConn_Close(t *testing.T) {
}
func TestConn_Close_BUSY(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare("BEGIN")
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
@@ -41,14 +44,92 @@ func TestConn_Close_BUSY(t *testing.T) {
}
}
func TestConn_Prepare_Empty(t *testing.T) {
func TestConn_SetInterrupt(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare("")
ctx, cancel := context.WithCancel(context.TODO())
db.SetInterrupt(ctx.Done())
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
db.SetInterrupt(nil)
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
cancel()
db.SetInterrupt(ctx.Done())
var serr *Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
db.SetInterrupt(nil)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_Empty(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
@@ -60,6 +141,8 @@ func TestConn_Prepare_Empty(t *testing.T) {
}
func TestConn_Prepare_Invalid(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -68,7 +151,7 @@ func TestConn_Prepare_Invalid(t *testing.T) {
var serr *Error
_, _, err = db.Prepare("SELECT")
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
@@ -82,7 +165,7 @@ func TestConn_Prepare_Invalid(t *testing.T) {
t.Error("got message: ", got)
}
_, _, err = db.Prepare("SELECT * FRM sqlite_schema")
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}
@@ -101,6 +184,8 @@ func TestConn_Prepare_Invalid(t *testing.T) {
}
func TestConn_new(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -112,7 +197,41 @@ func TestConn_new(t *testing.T) {
t.Error("want panic")
}
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
arena := db.newArena(16)
defer arena.reset()
const title = "Lorem ipsum"
ptr := arena.string(title)
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
const body = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
ptr = arena.string(body)
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
}
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -137,6 +256,8 @@ func TestConn_newBytes(t *testing.T) {
}
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -161,6 +282,8 @@ func TestConn_newString(t *testing.T) {
}
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -200,6 +323,8 @@ func TestConn_getString(t *testing.T) {
}
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)

View File

@@ -9,11 +9,15 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
ptrlen = 4
)
// ErrorCode is a result code that [Error.Code] might return.
//
// https://www.sqlite.org/rescode.html
type ErrorCode uint8
const (
@@ -47,6 +51,9 @@ const (
WARNING ErrorCode = 28 /* Warnings from sqlite3_log() */
)
// ExtendedErrorCode is a result code that [Error.ExtendedCode] might return.
//
// https://www.sqlite.org/rescode.html
type (
ExtendedErrorCode uint16
xErrorCode = ExtendedErrorCode
@@ -128,6 +135,9 @@ const (
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
// OpenFlag is a flag for a file open operation.
//
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
@@ -155,14 +165,17 @@ const (
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type AccessFlag uint32
type _AccessFlag uint32
const (
ACCESS_EXISTS AccessFlag = 0
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
ACCESS_READ AccessFlag = 2 /* Unused */
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_prepare_normalize.html
type PrepareFlag uint32
const (
@@ -171,6 +184,9 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// Datatype is a fundamental datatype of SQLite.
//
// https://www.sqlite.org/c3ref/c_blob.html
type Datatype uint32
const (

284
driver/driver.go Normal file
View File

@@ -0,0 +1,284 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
"context"
"database/sql"
"database/sql/driver"
"io"
"time"
"github.com/ncruces/go-sqlite3"
)
func init() {
sql.Register("sqlite3", sqlite{})
}
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
if err != nil {
return nil, err
}
// If the database is not in WAL mode,
// use normal locking mode.
journal, err := pragma(c, "journal_mode")
if err != nil {
return nil, err
}
if journal != "wal" {
pragma(c, "locking_mode=normal")
}
return conn{c}, nil
}
type conn struct{ conn *sqlite3.Conn }
var (
// Ensure these interfaces are implemented:
_ driver.Validator = conn{}
_ driver.ExecerContext = conn{}
// _ driver.ConnBeginTx = conn{}
// _ driver.SessionResetter = conn{}
)
func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() bool {
// Pool only normal locking mode connections.
mode, _ := pragma(c.conn, "locking_mode")
return mode == "normal"
}
func (c conn) Begin() (driver.Tx, error) {
err := c.conn.Exec(`BEGIN`)
if err != nil {
return nil, err
}
return c, nil
}
func (c conn) Commit() error {
err := c.conn.Exec(`COMMIT`)
if err != nil {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, tail, err := c.conn.Prepare(query)
if err != nil {
return nil, err
}
if tail != "" {
// Check if the tail contains any SQL.
s, _, err := c.conn.Prepare(tail)
if err != nil {
return nil, err
}
if s != nil {
s.Close()
return nil, tailErr
}
}
return stmt{s, c.conn}, nil
}
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
ch := c.conn.SetInterrupt(ctx.Done())
defer c.conn.SetInterrupt(ch)
err := c.conn.Exec(query)
if err != nil {
return nil, err
}
return result{
int64(c.conn.LastInsertRowID()),
int64(c.conn.Changes()),
}, nil
}
func pragma(c *sqlite3.Conn, pragma string) (string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + pragma)
if err != nil {
return "", err
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnText(0), nil
}
return "", stmt.Err()
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
)
func (s stmt) Close() error {
return s.stmt.Close()
}
func (s stmt) NumInput() int {
return s.stmt.BindCount()
}
// Deprecated: use ExecContext instead.
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), namedValues(args))
}
// Deprecated: use QueryContext instead.
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), namedValues(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
_, err := s.QueryContext(ctx, args)
if err != nil {
return nil, err
}
err = s.stmt.Exec()
if err != nil {
return nil, err
}
return result{
int64(s.conn.LastInsertRowID()),
int64(s.conn.Changes()),
}, nil
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.stmt.ClearBindings()
if err != nil {
return nil, err
}
var ids [3]int
for _, arg := range args {
ids := ids[:0]
if arg.Name == "" {
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range []string{":", "@", "$"} {
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
}
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
case float64:
err = s.stmt.BindFloat(id, a)
case string:
err = s.stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
}
}
if err != nil {
return nil, err
}
}
return rows{ctx, s.stmt, s.conn}, nil
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {
return r.lastInsertId, nil
}
func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type rows struct {
ctx context.Context
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
func (r rows) Close() error {
return r.stmt.Reset()
}
func (r rows) Columns() []string {
count := r.stmt.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.stmt.ColumnName(i)
}
return columns
}
func (r rows) Next(dest []driver.Value) error {
ch := r.conn.SetInterrupt(r.ctx.Done())
defer r.conn.SetInterrupt(ch)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
return err
}
return io.EOF
}
for i := range dest {
switch r.stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeDate(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
} else {
dest[i] = nil
}
default:
panic(assertErr)
}
}
return r.stmt.Err()
}

10
driver/error.go Normal file
View File

@@ -0,0 +1,10 @@
package driver
type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
)

149
driver/example_test.go Normal file
View File

@@ -0,0 +1,149 @@
package driver_test
// Adapted from: https://go.dev/doc/tutorial/database-access
import (
"database/sql"
"fmt"
"log"
"os"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
type Album struct {
ID int64
Title string
Artist string
Price float32
}
func Example() {
// Get a database handle.
var err error
db, err = sql.Open("sqlite3", "./recordings.db")
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
err = createAlbumsTable()
if err != nil {
log.Fatal(err)
}
albums, err := albumsByArtist("John Coltrane")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Albums found: %v\n", albums)
// Hard-code ID 2 here to test the query.
alb, err := albumByID(2)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Album found: %v\n", alb)
albID, err := addAlbum(Album{
Title: "The Modern Sound of Betty Carter",
Artist: "Betty Carter",
Price: 49.99,
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("ID of added album: %v\n", albID)
// Output:
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
// Album found: {2 Giant Steps John Coltrane 63.99}
// ID of added album: 5
}
func createAlbumsTable() error {
_, err := db.Exec(`
DROP TABLE IF EXISTS album;
CREATE TABLE album (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title VARCHAR(128) NOT NULL,
artist VARCHAR(255) NOT NULL,
price DECIMAL(5,2) NOT NULL
);
`)
if err != nil {
return err
}
_, err = db.Exec(`
INSERT INTO album
(title, artist, price)
VALUES
('Blue Train', 'John Coltrane', 56.99),
('Giant Steps', 'John Coltrane', 63.99),
('Jeru', 'Gerry Mulligan', 17.99),
('Sarah Vaughan', 'Sarah Vaughan', 34.98)
`)
if err != nil {
return err
}
return nil
}
// albumsByArtist queries for albums that have the specified artist name.
func albumsByArtist(name string) ([]Album, error) {
// An albums slice to hold data from returned rows.
var albums []Album
rows, err := db.Query("SELECT * FROM album WHERE artist = ?", name)
if err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
}
defer rows.Close()
// Loop through rows, using Scan to assign column data to struct fields.
for rows.Next() {
var alb Album
if err := rows.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
}
albums = append(albums, alb)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
}
return albums, nil
}
// albumByID queries for the album with the specified ID.
func albumByID(id int64) (Album, error) {
// An album to hold data from the returned row.
var alb Album
row := db.QueryRow("SELECT * FROM album WHERE id = ?", id)
if err := row.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
if err == sql.ErrNoRows {
return alb, fmt.Errorf("albumsById %d: no such album", id)
}
return alb, fmt.Errorf("albumsById %d: %w", id, err)
}
return alb, nil
}
// addAlbum adds the specified album to the database,
// returning the album ID of the new entry
func addAlbum(alb Album) (int64, error) {
result, err := db.Exec("INSERT INTO album (title, artist, price) VALUES (?, ?, ?)", alb.Title, alb.Artist, alb.Price)
if err != nil {
return 0, fmt.Errorf("addAlbum: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("addAlbum: %w", err)
}
return id, nil
}

19
driver/time.go Normal file
View File

@@ -0,0 +1,19 @@
package driver
import (
"database/sql/driver"
"time"
)
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database.sql] will recover the same string.
// TODO: optimize and fuzz test.
func maybeDate(text string) driver.Value {
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
return date
}
return text
}

14
driver/util.go Normal file
View File

@@ -0,0 +1,14 @@
package driver
import "database/sql/driver"
func namedValues(args []driver.Value) []driver.NamedValue {
named := make([]driver.NamedValue, len(args))
for i, v := range args {
named[i] = driver.NamedValue{
Ordinal: i + 1,
Value: v,
}
}
return named
}

18
driver/util_test.go Normal file
View File

@@ -0,0 +1,18 @@
package driver
import (
"database/sql/driver"
"reflect"
"testing"
)
func Test_namedValues(t *testing.T) {
want := []driver.NamedValue{
{Ordinal: 1, Value: true},
{Ordinal: 2, Value: false},
}
got := namedValues([]driver.Value{true, false})
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

View File

@@ -28,15 +28,23 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_parameter_index \
-Wl,--export=sqlite3_bind_parameter_name \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_interrupt \

View File

@@ -1,3 +1,7 @@
// Package embed embeds SQLite into your application.
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
package embed
import (

Binary file not shown.

View File

@@ -1,4 +1,4 @@
package main
package sqlite3_test
import (
"fmt"
@@ -8,8 +8,10 @@ import (
_ "github.com/ncruces/go-sqlite3/embed"
)
func main() {
db, err := sqlite3.Open(":memory:")
const memory = ":memory:"
func Example() {
db, err := sqlite3.Open(memory)
if err != nil {
log.Fatal(err)
}
@@ -45,4 +47,9 @@ func main() {
if err != nil {
log.Fatal(err)
}
// Output:
// 0 go
// 1 zig
// 2 whatever
}

8
mem.go
View File

@@ -99,9 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string {
}
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint32(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
siz := uint32(len(s) + 1)
buf := m.view(ptr, siz)
buf := m.view(ptr, uint32(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -1,3 +1,4 @@
#include <stdbool.h>
#include <stdlib.h>
#include <time.h>
@@ -33,6 +34,7 @@ int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int go_truncate(sqlite3_file *, sqlite3_int64 size);
int go_sync(sqlite3_file *, int flags);
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int go_file_control(sqlite3_file *pFile, int op, void *pArg);
int go_lock(sqlite3_file *pFile, int eLock);
int go_unlock(sqlite3_file *pFile, int eLock);
@@ -94,7 +96,7 @@ int sqlite3_os_init() {
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&go_vfs, /*default=*/1);
return sqlite3_vfs_register(&go_vfs, /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;

View File

@@ -5,6 +5,9 @@
#define SQLITE_OS_OTHER 1
#define SQLITE_BYTEORDER 1234
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define HAVE_ISNAN 1
#define HAVE_USLEEP 1
#define HAVE_LOCALTIME_S 1
@@ -25,12 +28,18 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Recommended Extensions
// #define SQLITE_ENABLE_MATH_FUNCTIONS 1
// #define SQLITE_ENABLE_FTS3 1
// #define SQLITE_ENABLE_FTS3_PARENTHESIS 1
// #define SQLITE_ENABLE_FTS4 1
// #define SQLITE_ENABLE_FTS5 1
// #define SQLITE_ENABLE_RTREE 1
// #define SQLITE_ENABLE_GEOPOLY 1
// Need this to access WAL databases without the use of shared memory.
#define SQLITE_DEFAULT_LOCKING_MODE 1
// Go uses UTF-8 everywhere.
#define SQLITE_OMIT_UTF16
// Remove some testing code.
#define SQLITE_UNTESTABLE
// Implemented in Go.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

96
stmt.go
View File

@@ -17,13 +17,13 @@ type Stmt struct {
//
// https://www.sqlite.org/c3ref/finalize.html
func (s *Stmt) Close() error {
if s == nil {
if s == nil || s.handle == 0 {
return nil
}
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
if err != nil {
return err
panic(err)
}
s.handle = 0
@@ -36,7 +36,7 @@ func (s *Stmt) Close() error {
func (s *Stmt) Reset() error {
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
if err != nil {
return err
panic(err)
}
s.err = nil
return s.c.error(r[0])
@@ -48,7 +48,7 @@ func (s *Stmt) Reset() error {
func (s *Stmt) ClearBindings() error {
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
if err != nil {
return err
panic(err)
}
return s.c.error(r[0])
}
@@ -65,8 +65,7 @@ func (s *Stmt) ClearBindings() error {
func (s *Stmt) Step() bool {
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
if err != nil {
s.err = err
return false
panic(err)
}
if r[0] == _ROW {
return true
@@ -95,6 +94,51 @@ func (s *Stmt) Exec() error {
return s.Reset()
}
// BindCount returns the number of SQL parameters in the prepared statement.
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r, err := s.c.api.bindCount.Call(s.c.ctx,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
// BindIndex returns the index of a parameter in the prepared statement
// given its name.
//
// https://www.sqlite.org/c3ref/bind_parameter_index.html
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.reset()
namePtr := s.c.arena.string(name)
r, err := s.c.api.bindIndex.Call(s.c.ctx,
uint64(s.handle), uint64(namePtr))
if err != nil {
panic(err)
}
return int(r[0])
}
// BindName returns the name of a parameter in the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r, err := s.c.api.bindName.Call(s.c.ctx,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
}
// BindBool binds a bool to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// SQLite does not have a separate boolean storage class.
@@ -124,7 +168,7 @@ func (s *Stmt) BindInt64(param int, value int64) error {
r, err := s.c.api.bindInteger.Call(s.c.ctx,
uint64(s.handle), uint64(param), uint64(value))
if err != nil {
return err
panic(err)
}
return s.c.error(r[0])
}
@@ -137,7 +181,7 @@ func (s *Stmt) BindFloat(param int, value float64) error {
r, err := s.c.api.bindFloat.Call(s.c.ctx,
uint64(s.handle), uint64(param), math.Float64bits(value))
if err != nil {
return err
panic(err)
}
return s.c.error(r[0])
}
@@ -153,7 +197,7 @@ func (s *Stmt) BindText(param int, value string) error {
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
if err != nil {
return err
panic(err)
}
return s.c.error(r[0])
}
@@ -170,7 +214,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
if err != nil {
return err
panic(err)
}
return s.c.error(r[0])
}
@@ -183,11 +227,41 @@ func (s *Stmt) BindNull(param int) error {
r, err := s.c.api.bindNull.Call(s.c.ctx,
uint64(s.handle), uint64(param))
if err != nil {
return err
panic(err)
}
return s.c.error(r[0])
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r, err := s.c.api.columnCount.Call(s.c.ctx,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
// ColumnName returns the name of the result column.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r, err := s.c.api.columnName.Call(s.c.ctx,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
}
// ColumnType returns the initial [Datatype] of the result column.
// The leftmost column of the result set has the index 0.
//

View File

@@ -6,6 +6,8 @@ import (
)
func TestStmt(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -23,6 +25,10 @@ func TestStmt(t *testing.T) {
}
defer stmt.Close()
if got := stmt.BindCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = stmt.BindBool(1, false)
if err != nil {
t.Fatal(err)
@@ -133,6 +139,7 @@ func TestStmt(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
@@ -359,3 +366,35 @@ func TestStmt_Close(t *testing.T) {
var stmt *Stmt
stmt.Close()
}
func TestStmt_BindName(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
want := []string{"", "", "", "", "?5", ":AAA", "@AAA", "$AAA"}
stmt, _, err := db.Prepare(`SELECT ?, ?5, :AAA, @AAA, $AAA`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if got := stmt.BindCount(); got != len(want) {
t.Errorf("got %d, want %d", got, len(want))
}
for i, name := range want {
id := i + 1
if got := stmt.BindName(id); got != name {
t.Errorf("got %q, want %q", got, name)
}
if name == "" {
id = 0
}
if got := stmt.BindIndex(name); got != id {
t.Errorf("got %d, want %d", got, id)
}
}
}

View File

@@ -1,7 +1,6 @@
package tests
import (
"os"
"path/filepath"
"testing"
@@ -14,13 +13,7 @@ func TestDB_memory(t *testing.T) {
}
func TestDB_file(t *testing.T) {
dir, err := os.MkdirTemp("", "sqlite3-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
testDB(t, filepath.Join(dir, "test.db"))
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func testDB(t *testing.T, name string) {
@@ -44,6 +37,7 @@ func testDB(t *testing.T, name string) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 1, 2}

View File

@@ -1,9 +1,10 @@
package tests
import (
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"testing"
"golang.org/x/sync/errgroup"
@@ -13,18 +14,53 @@ import (
)
func TestParallel(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip()
name := filepath.Join(t.TempDir(), "test.db")
testParallel(t, name, 100)
testIntegrity(t, name)
}
func TestMultiProcess(t *testing.T) {
if testing.Short() {
return
}
dir, err := os.MkdirTemp("", "sqlite3-")
name := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbname", name)
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
out, err := cmd.StdoutPipe()
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
var buf [3]byte
// Wait for child to start.
if _, err := io.ReadFull(out, buf[:]); err != nil || string(buf[:]) != "===" {
t.Fatal(err)
}
testParallel(t, name, 1000)
if err := cmd.Wait(); err != nil {
t.Fatal(err)
}
testIntegrity(t, name)
}
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
return
}
testParallel(t, name, 1000)
}
func testParallel(t *testing.T, name string, n int) {
writer := func() error {
db, err := sqlite3.Open(filepath.Join(dir, "test.db"))
db, err := sqlite3.Open(name)
if err != nil {
return err
}
@@ -32,7 +68,7 @@ func TestParallel(t *testing.T) {
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 1000;
PRAGMA busy_timeout = 10000;
`)
if err != nil {
return err
@@ -52,7 +88,7 @@ func TestParallel(t *testing.T) {
}
reader := func() error {
db, err := sqlite3.Open(filepath.Join(dir, "test.db"))
db, err := sqlite3.Open(name)
if err != nil {
return err
}
@@ -60,7 +96,7 @@ func TestParallel(t *testing.T) {
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 1000;
PRAGMA busy_timeout = 10000;
`)
if err != nil {
return err
@@ -70,6 +106,7 @@ func TestParallel(t *testing.T) {
if err != nil {
return err
}
defer stmt.Close()
row := 0
for stmt.Step() {
@@ -90,14 +127,14 @@ func TestParallel(t *testing.T) {
return db.Close()
}
err = writer()
err := writer()
if err != nil {
t.Fatal(err)
}
var group errgroup.Group
group.SetLimit(4)
for i := 0; i < 32; i++ {
for i := 0; i < n; i++ {
if i&7 != 7 {
group.Go(reader)
} else {
@@ -109,3 +146,41 @@ func TestParallel(t *testing.T) {
t.Fatal(err)
}
}
func testIntegrity(t *testing.T, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)
}
defer db.Close()
test := `PRAGMA integrity_check`
if testing.Short() {
test = `PRAGMA quick_check`
}
stmt, _, err := db.Prepare(test)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
if row := stmt.ColumnText(0); row != "ok" {
t.Error(row)
}
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}

35
vfs.go
View File

@@ -18,12 +18,12 @@ import (
"github.com/tetratelabs/wazero/sys"
)
func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) {
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
_, err = wasi.Instantiate(ctx)
_, err := wasi.Instantiate(ctx)
if err != nil {
return err
panic(err)
}
env := r.NewHostModuleBuilder("env")
@@ -45,8 +45,11 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) {
env.NewFunctionBuilder().WithFunc(vfsLock).Export("go_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("go_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("go_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("go_file_control")
_, err = env.Instantiate(ctx)
return err
if err != nil {
panic(err)
}
}
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
@@ -112,11 +115,11 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
siz := uint32(len(abs) + 1)
if siz > nFull {
size := uint32(len(abs) + 1)
if size > nFull {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, siz)
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
@@ -145,7 +148,7 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32)
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) uint32 {
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
@@ -154,7 +157,7 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags Ac
var res uint32
switch {
case flags == ACCESS_EXISTS:
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
@@ -166,7 +169,7 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags Ac
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == ACCESS_READWRITE {
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
@@ -304,3 +307,15 @@ func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint3
memory{mod}.writeUint64(pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
return uint32(NOTFOUND)
}

View File

@@ -64,7 +64,7 @@ type vfsFileLocker struct {
}
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// SQLite never explicitly requests a pendig lock.
// Argument check. SQLite never explicitly requests a pendig lock.
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
panic(assertErr())
}
@@ -72,12 +72,10 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
// Connection state check.
panic(assertErr())
case cLock == _NO_LOCK && eLock > _SHARED_LOCK:
// We never move from unlocked to anything higher than a shared lock.
panic(assertErr())
@@ -86,31 +84,51 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
panic(assertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
// File state check.
switch {
case fLock.state < _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
panic(assertErr())
case fLock.state == _NO_LOCK && fLock.shared != 0:
panic(assertErr())
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
panic(assertErr())
case fLock.state != _NO_LOCK && fLock.shared <= 0:
panic(assertErr())
case fLock.state < cLock:
panic(assertErr())
}
// If some other connection has a lock that precludes the requested lock, return BUSY.
if cLock != fLock.state && (eLock > _SHARED_LOCK || fLock.state >= _PENDING_LOCK) {
return uint32(BUSY)
}
// If a SHARED lock is requested, and some other connection has a SHARED or RESERVED lock,
// then increment the reference count and return OK.
if eLock == _SHARED_LOCK && (fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK) {
if cLock != _NO_LOCK || fLock.shared <= 0 {
panic(assertErr())
}
ptr.SetLock(_SHARED_LOCK)
fLock.shared++
return _OK
}
// If control gets to this point, then actually go ahead and make
// operating system calls for the specified lock.
switch eLock {
case _SHARED_LOCK:
if fLock.state != _NO_LOCK || fLock.shared != 0 {
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := fLock.CheckPending(); locked {
return uint32(BUSY)
}
// If some other connection has a SHARED or RESERVED lock,
// increment the reference count and return OK.
if fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK {
ptr.SetLock(_SHARED_LOCK)
fLock.shared++
return _OK
}
// Must be unlocked to get SHARED.
if fLock.state != _NO_LOCK {
panic(assertErr())
}
if rc := fLock.GetShared(); rc != _OK {
@@ -122,7 +140,8 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
return _OK
case _RESERVED_LOCK:
if fLock.state != _SHARED_LOCK || fLock.shared <= 0 {
// Must be SHARED to get RESERVED.
if fLock.state != _SHARED_LOCK {
panic(assertErr())
}
if rc := fLock.GetReserved(); rc != _OK {
@@ -133,7 +152,8 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
return _OK
case _EXCLUSIVE_LOCK:
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK || fLock.shared <= 0 {
// Must be SHARED, PENDING or RESERVED to get EXCLUSIVE.
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK {
panic(assertErr())
}
@@ -164,6 +184,7 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check.
if eLock != _NO_LOCK && eLock != _SHARED_LOCK {
panic(assertErr())
}
@@ -171,6 +192,11 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
panic(assertErr())
}
// If we don't have a more restrictive lock, do nothing.
if cLock <= eLock {
return _OK
@@ -180,10 +206,20 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
fLock.Lock()
defer fLock.Unlock()
if fLock.shared <= 0 {
// File state check.
switch {
case fLock.state <= _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
panic(assertErr())
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
panic(assertErr())
case fLock.shared <= 0:
panic(assertErr())
case fLock.state < cLock:
panic(assertErr())
}
if cLock > _SHARED_LOCK {
// The connection must own the lock to release it.
if cLock != fLock.state {
panic(assertErr())
}
@@ -197,6 +233,7 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
}
}
// If we get here, make sure we're dropping all locks.
if eLock != _NO_LOCK {
panic(assertErr())
}

View File

@@ -10,7 +10,7 @@ import (
func Test_vfsLock(t *testing.T) {
// Other OSes lack open file descriptors locks.
switch runtime.GOOS {
case "linux", "darwin", "solaris", "windows":
case "linux", "darwin", "illumos", "windows":
//
default:
t.Skip()

View File

@@ -163,16 +163,10 @@ func Test_vfsDelete(t *testing.T) {
}
func Test_vfsAccess(t *testing.T) {
dir, err := os.MkdirTemp("", "sqlite3-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
mem.writeString(8, t.TempDir())
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_EXISTS, 4)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -180,7 +174,7 @@ func Test_vfsAccess(t *testing.T) {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_READWRITE, 4)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}

View File

@@ -13,19 +13,8 @@ func deleteOnClose(f *os.File) {
}
func (l *vfsFileLocker) GetShared() xErrorCode {
// A PENDING lock is needed before acquiring a SHARED lock.
if rc := l.readLock(_PENDING_BYTE, 1); rc != _OK {
return rc
}
// Acquire the SHARED lock.
rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE)
// Drop the temporary PENDING lock.
if rc2 := l.unlock(_PENDING_BYTE, 1); rc == _OK {
return rc2
}
return rc
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
@@ -68,6 +57,11 @@ func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
return l.checkLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
// Test the PENDING lock.
return l.checkLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) unlock(start, len int64) xErrorCode {
err := l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_UNLCK,
@@ -85,7 +79,7 @@ func (l *vfsFileLocker) readLock(start, len int64) xErrorCode {
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}), IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len int64) xErrorCode {
@@ -117,7 +111,7 @@ func (l *vfsFileLocker) fcntlGetLock(lock *syscall.Flock_t) error {
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_GETLK = 92 // F_OFD_GETLK
case "solaris":
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_GETLK = 47 // F_OFD_GETLK
}
@@ -133,7 +127,7 @@ func (l *vfsFileLocker) fcntlSetLock(lock *syscall.Flock_t) error {
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_SETLK = 90 // F_OFD_SETLK
case "solaris":
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_SETLK = 48 // F_OFD_SETLK
}
@@ -146,13 +140,14 @@ func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case syscall.EACCES:
case syscall.EAGAIN:
case syscall.EBUSY:
case syscall.EINTR:
case syscall.ENOLCK:
case syscall.EDEADLK:
case syscall.ETIMEDOUT:
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
return xErrorCode(PERM)

View File

@@ -10,19 +10,8 @@ import (
func deleteOnClose(f *os.File) {}
func (l *vfsFileLocker) GetShared() xErrorCode {
// A PENDING lock is needed before acquiring a SHARED lock.
if rc := l.readLock(_PENDING_BYTE, 1); rc != _OK {
return rc
}
// Acquire the SHARED lock.
rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE)
// Drop the temporary PENDING lock.
if rc2 := l.unlock(_PENDING_BYTE, 1); rc == _OK {
return rc2
}
return rc
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
@@ -83,6 +72,15 @@ func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
return rc != _OK, _OK
}
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
// Test the PENDING lock.
rc := l.readLock(_PENDING_BYTE, 1)
if rc == _OK {
l.unlock(_PENDING_BYTE, 1)
}
return rc != _OK, _OK
}
func (l *vfsFileLocker) unlock(start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(l.file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
@@ -96,7 +94,7 @@ func (l *vfsFileLocker) readLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len uint32) xErrorCode {