diff --git a/blob.go b/blob.go index 010948e..268dfab 100644 --- a/blob.go +++ b/blob.go @@ -31,7 +31,6 @@ var _ io.ReadWriteSeeker = &Blob{} // // https://sqlite.org/c3ref/blob_open.html func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { - c.checkInterrupt() defer c.arena.mark()() blobPtr := c.arena.new(ptrlen) dbPtr := c.arena.string(db) @@ -43,6 +42,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, flags = 1 } + c.checkInterrupt(c.handle) r := c.call("sqlite3_blob_open", uint64(c.handle), uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), uint64(row), flags, uint64(blobPtr)) diff --git a/conn.go b/conn.go index 10d4d17..cf7b46f 100644 --- a/conn.go +++ b/conn.go @@ -40,12 +40,18 @@ type Conn struct { // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. func Open(filename string) (*Conn, error) { - return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) + return newConn(context.Background(), filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) +} + +// OpenContext is like [Open] but includes a context, +// which is used to interrupt the process of opening the connectiton. +func OpenContext(ctx context.Context, filename string) (*Conn, error) { + return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) } // OpenFlags opens an SQLite database file as specified by the filename argument. // -// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used. +// If none of the required flags are used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used. // If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma": // // sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)") @@ -55,33 +61,37 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 { flags |= OPEN_READWRITE | OPEN_CREATE } - return newConn(filename, flags) + return newConn(context.Background(), filename, flags) } type connKey struct{} -func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { +func newConn(ctx context.Context, filename string, flags OpenFlag) (conn *Conn, err error) { + err = ctx.Err() + if err != nil { + return nil, err + } sqlite, err := instantiateSQLite() if err != nil { return nil, err } defer func() { - if conn == nil { + if err != nil { + conn.Close() + conn = nil sqlite.close() } }() - c := &Conn{sqlite: sqlite} - c.arena = c.newArena(1024) + c := &Conn{sqlite: sqlite, interrupt: ctx} c.ctx = context.WithValue(c.ctx, connKey{}, c) + c.arena = c.newArena(1024) c.handle, err = c.openDB(filename, flags) if err == nil { err = initExtensions(c) } - if err != nil { - return nil, err - } - return c, nil + c.interrupt = context.Background() + return c, err } func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { @@ -98,6 +108,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { return 0, err } + c.call("sqlite3_progress_handler_go", uint64(handle), 100) if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { var pragmas strings.Builder if _, after, ok := strings.Cut(filename, "?"); ok { @@ -109,6 +120,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { } } if pragmas.Len() != 0 { + c.checkInterrupt(handle) pragmaPtr := c.arena.string(pragmas.String()) r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0) if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { @@ -118,7 +130,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { } } } - c.call("sqlite3_progress_handler_go", uint64(handle), 100) return handle, nil } @@ -160,10 +171,10 @@ func (c *Conn) Close() error { // // https://sqlite.org/c3ref/exec.html func (c *Conn) Exec(sql string) error { - c.checkInterrupt() defer c.arena.mark()() sqlPtr := c.arena.string(sql) + c.checkInterrupt(c.handle) r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0) return c.error(r, sql) } @@ -301,8 +312,7 @@ func (c *Conn) ReleaseMemory() error { return c.error(r) } -// GetInterrupt gets the context set with [Conn.SetInterrupt], -// or nil if none was set. +// GetInterrupt gets the context set with [Conn.SetInterrupt]. func (c *Conn) GetInterrupt() context.Context { return c.interrupt } @@ -322,9 +332,11 @@ func (c *Conn) GetInterrupt() context.Context { // // https://sqlite.org/c3ref/interrupt.html func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { - // Is it the same context? - if ctx == c.interrupt { - return ctx + old = c.interrupt + c.interrupt = ctx + + if ctx == old || ctx.Done() == old.Done() { + return old } // A busy SQL statement prevents SQLite from ignoring an interrupt @@ -333,32 +345,29 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) - c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, 0, uint64(stmtPtr), 0) + c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, + uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0) c.pending = &Stmt{c: c} c.pending.handle = util.ReadUint32(c.mod, stmtPtr) } - old = c.interrupt - c.interrupt = ctx - - if old != nil && old.Done() != nil && (ctx == nil || ctx.Err() == nil) { + if old.Done() != nil && ctx.Err() == nil { c.pending.Reset() } - if ctx != nil && ctx.Done() != nil { + if ctx.Done() != nil { c.pending.Step() } return old } -func (c *Conn) checkInterrupt() { - if c.interrupt != nil && c.interrupt.Err() != nil { - c.call("sqlite3_interrupt", uint64(c.handle)) +func (c *Conn) checkInterrupt(handle uint32) { + if c.interrupt.Err() != nil { + c.call("sqlite3_interrupt", uint64(handle)) } } -func progressCallback(ctx context.Context, mod api.Module, pDB uint32) (interrupt uint32) { - if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && - c.interrupt != nil && c.interrupt.Err() != nil { +func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() != nil { interrupt = 1 } return interrupt @@ -373,9 +382,8 @@ func (c *Conn) BusyTimeout(timeout time.Duration) error { return c.error(r) } -func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmout int32) (retry uint32) { - if c, ok := ctx.Value(connKey{}).(*Conn); ok && - (c.interrupt == nil || c.interrupt.Err() == nil) { +func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil { const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64" const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4" const ndelay = int32(len(delays) - 1) @@ -391,7 +399,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmo if delay = min(delay, tmout-prior); delay > 0 { delay := time.Duration(delay) * time.Millisecond - if c.interrupt == nil || c.interrupt.Done() == nil { + if c.interrupt.Done() == nil { time.Sleep(delay) return 1 } diff --git a/driver/driver.go b/driver/driver.go index 720772a..086a87e 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -209,7 +209,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { tmWrite: n.tmWrite, } - c.Conn, err = sqlite3.Open(n.name) + c.Conn, err = sqlite3.OpenContext(ctx, n.name) if err != nil { return nil, err } diff --git a/embed/bcw2/bcw2.wasm b/embed/bcw2/bcw2.wasm index ad8670f..23cf0f1 100755 Binary files a/embed/bcw2/bcw2.wasm and b/embed/bcw2/bcw2.wasm differ diff --git a/embed/exports.txt b/embed/exports.txt index b624ee1..5460195 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -51,6 +51,7 @@ sqlite3_create_collation_go sqlite3_create_function_go sqlite3_create_module_go sqlite3_create_window_function_go +sqlite3_data_count sqlite3_database_file_object sqlite3_db_cacheflush sqlite3_db_config diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 59eeb6c..749a6ed 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/sqlite.go b/sqlite.go index f91fdcb..a5ff136 100644 --- a/sqlite.go +++ b/sqlite.go @@ -301,7 +301,7 @@ func (a *arena) string(s string) uint32 { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncII(env, "go_progress_handler", progressCallback) - util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback) + util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback) util.ExportFuncIII(env, "go_busy_handler", busyCallback) util.ExportFuncII(env, "go_commit_hook", commitCallback) util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) diff --git a/sqlite3/hooks.c b/sqlite3/hooks.c index 1bdb405..7bf7547 100644 --- a/sqlite3/hooks.c +++ b/sqlite3/hooks.c @@ -4,7 +4,7 @@ int go_progress_handler(void *); int go_busy_handler(void *, int); -int go_busy_timeout(void *, int count, int tmout); +int go_busy_timeout(int count, int tmout); int go_commit_hook(void *); void go_rollback_hook(void *); @@ -20,7 +20,7 @@ unsigned int go_autovacuum_pages(void *, const char *, unsigned int, unsigned int, unsigned int); void sqlite3_progress_handler_go(sqlite3 *db, int n) { - sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/db); + sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/NULL); } int sqlite3_busy_handler_go(sqlite3 *db, bool enable) { @@ -66,7 +66,7 @@ int sqlite3_autovacuum_pages_go(sqlite3 *db, go_handle app) { #ifndef sqliteBusyCallback static int sqliteBusyCallback(sqlite3 *db, int count) { - return go_busy_timeout(db, count, db->busyTimeout); + return go_busy_timeout(count, db->busyTimeout); } #endif \ No newline at end of file diff --git a/stmt.go b/stmt.go index 2e1f648..82f9fb7 100644 --- a/stmt.go +++ b/stmt.go @@ -106,7 +106,7 @@ func (s *Stmt) Busy() bool { // // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { - s.c.checkInterrupt() + s.c.checkInterrupt(s.c.handle) r := s.c.call("sqlite3_step", uint64(s.handle)) switch r { case _ROW: @@ -377,6 +377,15 @@ func (s *Stmt) BindValue(param int, value Value) error { return s.c.error(r) } +// DataCount resets the number of columns in a result set. +// +// https://www.sqlite.org/c3ref/data_count.html +func (s *Stmt) DataCount() int { + r := s.c.call("sqlite3_data_count", + uint64(s.handle)) + return int(int32(r)) +} + // ColumnCount returns the number of columns in a result set. // // https://sqlite.org/c3ref/column_count.html diff --git a/tests/stmt_test.go b/tests/stmt_test.go index c0cb267..0a3fe45 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -146,6 +146,12 @@ func TestStmt(t *testing.T) { if got := stmt.ReadOnly(); got != true { t.Error("got false, want true") } + if got := stmt.DataCount(); got != 0 { + t.Errorf("got %d, want 0", got) + } + if got := stmt.ColumnCount(); got != 1 { + t.Errorf("got %d, want 1", got) + } if got := stmt.ColumnName(0); got != "c" { t.Errorf(`got %q, want "c"`, got) } @@ -503,6 +509,10 @@ func TestStmt(t *testing.T) { } } + if got := stmt.DataCount(); got != 1 { + t.Errorf("got %d, want 1", got) + } + db.Stmts()(func(s *sqlite3.Stmt) bool { if s != stmt { t.Error()