diff --git a/conn.go b/conn.go index 3d19b70..4ffa27c 100644 --- a/conn.go +++ b/conn.go @@ -7,10 +7,9 @@ import ( "net/url" "runtime" "strings" - "sync/atomic" - "unsafe" "github.com/ncruces/go-sqlite3/internal/util" + "github.com/tetratelabs/wazero/api" ) // Conn is a database connection handle. @@ -21,7 +20,6 @@ type Conn struct { *sqlite interrupt context.Context - waiter chan struct{} pending *Stmt arena arena @@ -48,6 +46,8 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { return newConn(filename, flags) } +type connKey struct{} + func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { sqlite, err := instantiateSQLite() if err != nil { @@ -63,6 +63,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { c := &Conn{sqlite: sqlite} c.arena = c.newArena(1024) + c.ctx = context.WithValue(c.ctx, connKey{}, c) c.handle, err = c.openDB(filename, flags) if err != nil { return nil, err @@ -131,7 +132,6 @@ func (c *Conn) Close() error { return nil } - c.SetInterrupt(context.Background()) c.pending.Close() c.pending = nil @@ -244,65 +244,40 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { return ctx } - // Is a waiter running? - if c.waiter != nil { - c.waiter <- struct{}{} // Cancel the waiter. - <-c.waiter // Wait for it to finish. - c.waiter = nil - } - // Reset the pending statement. - if c.pending != nil { + // An uncompleted SQL statement prevents SQLite from ignoring + // an interrupt that comes before any other statements are started. + if c.pending == nil { + c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`) + } else { c.pending.Reset() } old = c.interrupt c.interrupt = ctx + // Remove the handler if the context can't be canceled. if ctx == nil || ctx.Done() == nil { + c.call(c.api.progressHandler, uint64(c.handle), 0) return old } - // Creating an uncompleted SQL statement prevents SQLite from ignoring - // an interrupt that comes before any other statements are started. - if c.pending == nil { - c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`) - } c.pending.Step() - - // Don't create the goroutine if we're already interrupted. - // This happens frequently while restoring to a previously interrupted state. - if c.checkInterrupt() { - return old - } - - waiter := make(chan struct{}) - c.waiter = waiter - go func() { - select { - case <-waiter: // Waiter was cancelled. - break - - case <-ctx.Done(): // Done was closed. - const isInterruptedOffset = 288 - buf := util.View(c.mod, c.handle+isInterruptedOffset, 4) - (*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1) - // Wait for the next call to SetInterrupt. - <-waiter - } - - // Signal that the waiter has finished. - waiter <- struct{}{} - }() + c.call(c.api.progressHandler, uint64(c.handle), 100) return old } -func (c *Conn) checkInterrupt() bool { - if c.interrupt == nil || c.interrupt.Err() == nil { - return false +func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 { + if c, ok := ctx.Value(connKey{}).(*Conn); ok { + if c.interrupt != nil && c.interrupt.Err() != nil { + return 1 + } + } + return 0 +} + +func (c *Conn) checkInterrupt() { + if c.interrupt != nil && c.interrupt.Err() != nil { + c.call(c.api.interrupt, uint64(c.handle)) } - const isInterruptedOffset = 288 - buf := util.View(c.mod, c.handle+isInterruptedOffset, 4) - (*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1) - return true } // Pragma executes a PRAGMA statement and returns any results. diff --git a/driver/driver.go b/driver/driver.go index 6bd4bb9..f788be4 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -31,6 +31,7 @@ import ( "database/sql" "database/sql/driver" "encoding/json" + "errors" "fmt" "io" "net/url" @@ -225,7 +226,13 @@ func (c *conn) Commit() error { } func (c *conn) Rollback() error { - return c.Conn.Exec(c.txRollback) + err := c.Conn.Exec(c.txRollback) + if errors.Is(err, sqlite3.INTERRUPT) { + old := c.Conn.SetInterrupt(context.Background()) + defer c.Conn.SetInterrupt(old) + err = c.Conn.Exec(c.txRollback) + } + return err } func (c *conn) Prepare(query string) (driver.Stmt, error) { diff --git a/embed/exports.txt b/embed/exports.txt index 4f10313..b96d40a 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -13,6 +13,8 @@ sqlite3_finalize sqlite3_reset sqlite3_step sqlite3_exec +sqlite3_interrupt +sqlite3_progress_handler_go sqlite3_clear_bindings sqlite3_bind_parameter_count sqlite3_bind_parameter_index diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 9361e4e..7f456bb 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/func.go b/func.go index 4de8b98..fc4f846 100644 --- a/func.go +++ b/func.go @@ -4,7 +4,6 @@ import ( "context" "github.com/ncruces/go-sqlite3/internal/util" - "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" ) @@ -47,6 +46,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func( // CreateWindowFunction defines a new aggregate or aggregate window SQL function. // If fn returns a [WindowFunction], then an aggregate window function is created. +// If fn returns an [io.Closer], it will be called to free resources. // // https://www.sqlite.org/c3ref/create_function.html func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { @@ -70,7 +70,7 @@ type AggregateFunction interface { // The function arguments, if any, corresponding to the row being added are passed to Step. Step(ctx Context, arg ...Value) - // Value is invoked to return the current value of the aggregate. + // Value is invoked to return the current (or final) value of the aggregate. Value(ctx Context) } @@ -85,17 +85,6 @@ type WindowFunction interface { Inverse(ctx Context, arg ...Value) } -func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { - util.ExportFuncVI(env, "go_destroy", callbackDestroy) - util.ExportFuncIIIIII(env, "go_compare", callbackCompare) - util.ExportFuncVIII(env, "go_func", callbackFunc) - util.ExportFuncVIII(env, "go_step", callbackStep) - util.ExportFuncVI(env, "go_final", callbackFinal) - util.ExportFuncVI(env, "go_value", callbackValue) - util.ExportFuncVIII(env, "go_inverse", callbackInverse) - return env -} - func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) { util.DelHandle(ctx, pApp) } @@ -106,20 +95,20 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK } func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { - sqlite := ctx.Value(sqliteKey{}).(*sqlite) + sqlite := ctx.Value(connKey{}).(*Conn).sqlite fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value)) fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...) } func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { - sqlite := ctx.Value(sqliteKey{}).(*sqlite) + sqlite := ctx.Value(connKey{}).(*Conn).sqlite fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction) fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...) } func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) { var handle uint32 - sqlite := ctx.Value(sqliteKey{}).(*sqlite) + sqlite := ctx.Value(connKey{}).(*Conn).sqlite fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction) fn.Value(Context{sqlite, pCtx}) if err := util.DelHandle(ctx, handle); err != nil { @@ -128,13 +117,13 @@ func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) { } func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) { - sqlite := ctx.Value(sqliteKey{}).(*sqlite) + sqlite := ctx.Value(connKey{}).(*Conn).sqlite fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction) fn.Value(Context{sqlite, pCtx}) } func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { - sqlite := ctx.Value(sqliteKey{}).(*sqlite) + sqlite := ctx.Value(connKey{}).(*Conn).sqlite fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction) fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...) } diff --git a/sqlite.go b/sqlite.go index ce9966f..bbf900f 100644 --- a/sqlite.go +++ b/sqlite.go @@ -43,7 +43,7 @@ func compileSQLite() { env := instance.runtime.NewHostModuleBuilder("env") env = vfs.ExportHostFunctions(env) - env = exportHostFunctions(env) + env = exportCallbacks(env) _, instance.err = env.Instantiate(ctx) if instance.err != nil { return @@ -71,8 +71,6 @@ type sqlite struct { stack [8]uint64 } -type sqliteKey struct{} - func instantiateSQLite() (sqlt *sqlite, err error) { instance.once.Do(compileSQLite) if instance.err != nil { @@ -81,7 +79,6 @@ func instantiateSQLite() (sqlt *sqlite, err error) { sqlt = new(sqlite) sqlt.ctx = util.NewContext(context.Background()) - sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt) sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx, instance.compiled, wazero.NewModuleConfig()) @@ -123,6 +120,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) { reset: getFun("sqlite3_reset"), step: getFun("sqlite3_step"), exec: getFun("sqlite3_exec"), + interrupt: getFun("sqlite3_interrupt"), + progressHandler: getFun("sqlite3_progress_handler_go"), clearBindings: getFun("sqlite3_clear_bindings"), bindCount: getFun("sqlite3_bind_parameter_count"), bindIndex: getFun("sqlite3_bind_parameter_index"), @@ -342,6 +341,8 @@ type sqliteAPI struct { reset api.Function step api.Function exec api.Function + interrupt api.Function + progressHandler api.Function clearBindings api.Function bindCount api.Function bindIndex api.Function @@ -402,3 +403,15 @@ type sqliteAPI struct { resultErrorBig api.Function destructor uint32 } + +func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { + util.ExportFuncII(env, "go_progress", callbackProgress) + util.ExportFuncVI(env, "go_destroy", callbackDestroy) + util.ExportFuncIIIIII(env, "go_compare", callbackCompare) + util.ExportFuncVIII(env, "go_func", callbackFunc) + util.ExportFuncVIII(env, "go_step", callbackStep) + util.ExportFuncVI(env, "go_final", callbackFinal) + util.ExportFuncVI(env, "go_value", callbackValue) + util.ExportFuncVIII(env, "go_inverse", callbackInverse) + return env +} diff --git a/sqlite3/main.c b/sqlite3/main.c index 50f5867..894464b 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -11,6 +11,7 @@ #include "ext/uint.c" #include "ext/uuid.c" #include "func.c" +#include "progress.c" #include "time.c" __attribute__((constructor)) void init() { diff --git a/sqlite3/progress.c b/sqlite3/progress.c new file mode 100644 index 0000000..8049ecb --- /dev/null +++ b/sqlite3/progress.c @@ -0,0 +1,9 @@ +#include + +#include "sqlite3.h" + +int go_progress(void *); + +void sqlite3_progress_handler_go(sqlite3 *db, int n) { + sqlite3_progress_handler(db, n, go_progress, NULL); +} diff --git a/sqlite3/sqlite_cfg.h b/sqlite3/sqlite_cfg.h index 89b8681..c70982c 100644 --- a/sqlite3/sqlite_cfg.h +++ b/sqlite3/sqlite_cfg.h @@ -23,7 +23,6 @@ #define SQLITE_MAX_EXPR_DEPTH 0 #define SQLITE_OMIT_DECLTYPE #define SQLITE_OMIT_DEPRECATED -#define SQLITE_OMIT_PROGRESS_CALLBACK #define SQLITE_OMIT_SHARED_CACHE #define SQLITE_OMIT_AUTOINIT #define SQLITE_USE_ALLOCA diff --git a/sqlite3/vfs.c b/sqlite3/vfs.c index cdc9fe5..ec6cdf2 100644 --- a/sqlite3/vfs.c +++ b/sqlite3/vfs.c @@ -137,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) { return sqlite3_vfs_find_orig(zVfsName); } -static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset"); static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset"); -static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset"); \ No newline at end of file +static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset"); \ No newline at end of file diff --git a/tests/conn_test.go b/tests/conn_test.go index 88ae11e..8d5c143 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) { defer stmt.Close() db.SetInterrupt(ctx) - cancel() + go cancel() // Interrupting works. err = stmt.Exec() diff --git a/vfs/tests/mptest/testdata/mptest.wasm.bz2 b/vfs/tests/mptest/testdata/mptest.wasm.bz2 index db9b87e..f22e2ed 100644 --- a/vfs/tests/mptest/testdata/mptest.wasm.bz2 +++ b/vfs/tests/mptest/testdata/mptest.wasm.bz2 @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:da493a827d5b2985ba80d7425092a891311633d6c80c559119f87609d0f0e02a -size 508796 +oid sha256:c59231ce10786b45be958027d23cffc74894a00120b30c8d3accb26f4182b29a +size 509312 diff --git a/vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2 b/vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2 index d7f1fa2..31601d9 100644 --- a/vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2 +++ b/vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2 @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c5ffb1dc0f046bb7e4ea4cd992e1d625c05c7d5cf33c78edc5a2155eb3d3c097 -size 523393 +oid sha256:9f715bad486eeae35ecb3cf05a2e6265fbfc24a2de0836bdc8fd760510ac1d3a +size 524127