diff --git a/conn.go b/conn.go index 8ba034f..f170ccf 100644 --- a/conn.go +++ b/conn.go @@ -346,10 +346,9 @@ func (c *Conn) checkInterrupt() { } func progressCallback(ctx context.Context, mod api.Module, pDB uint32) (interrupt uint32) { - if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.interrupt != nil { - if c.interrupt.Err() != nil { - interrupt = 1 - } + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && + c.interrupt != nil && c.interrupt.Err() != nil { + interrupt = 1 } return interrupt } @@ -363,6 +362,30 @@ func (c *Conn) BusyTimeout(timeout time.Duration) error { return c.error(r) } +func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmout int32) (retry uint32) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && + (c.interrupt == nil || c.interrupt.Err() == nil) { + const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64" + const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4" + const ndelay = int32(len(delays) - 1) + + var delay, prior int32 + if count <= ndelay { + delay = int32(delays[count]) + prior = int32(totals[count]) + } else { + delay = int32(delays[ndelay]) + prior = int32(totals[ndelay]) + delay*(count-ndelay) + } + + if delay = min(delay, tmout-prior); delay > 0 { + time.Sleep(time.Duration(delay) * time.Millisecond) + retry = 1 + } + } + return retry +} + // BusyHandler registers a callback to handle [BUSY] errors. // // https://sqlite.org/c3ref/busy_handler.html @@ -380,7 +403,8 @@ func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error { } func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) { - if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil && + (c.interrupt == nil || c.interrupt.Err() == nil) { if c.busy(int(count)) { retry = 1 } diff --git a/embed/build.sh b/embed/build.sh index 54aa170..abe5e60 100755 --- a/embed/build.sh +++ b/embed/build.sh @@ -8,7 +8,7 @@ BINARYEN="$ROOT/tools/binaryen-version_117/bin" WASI_SDK="$ROOT/tools/wasi-sdk-22.0/bin" "$WASI_SDK/clang" --target=wasm32-wasi -std=c17 -flto -g0 -O2 \ - -Wall -Wextra -Wno-unused-parameter \ + -Wall -Wextra -Wno-unused-parameter -Wno-unused-function \ -o sqlite3.wasm "$ROOT/sqlite3/main.c" \ -I"$ROOT/sqlite3" \ -mexec-model=reactor \ diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 6994d4d..519eb61 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/sqlite.go b/sqlite.go index 593ba33..61a0365 100644 --- a/sqlite.go +++ b/sqlite.go @@ -296,8 +296,9 @@ func (a *arena) string(s string) uint32 { } func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { - util.ExportFuncIII(env, "go_busy_handler", busyCallback) util.ExportFuncII(env, "go_progress_handler", progressCallback) + util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback) + util.ExportFuncIII(env, "go_busy_handler", busyCallback) util.ExportFuncII(env, "go_commit_hook", commitCallback) util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback) diff --git a/sqlite3/busy_timeout.patch b/sqlite3/busy_timeout.patch new file mode 100644 index 0000000..e9bbf63 --- /dev/null +++ b/sqlite3/busy_timeout.patch @@ -0,0 +1,13 @@ +# Replace sqliteDefaultBusyCallback. +# This patch allows Go to handle (and interrupt) sqlite3_busy_timeout. +--- sqlite3.c.orig ++++ sqlite3.c +@@ -181581,7 +181581,7 @@ + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; + #endif + if( ms>0 ){ +- sqlite3_busy_handler(db, (int(*)(void*,int))sqliteDefaultBusyCallback, ++ sqlite3_busy_handler(db, (int(*)(void*,int))sqliteBusyCallback, + (void*)db); + db->busyTimeout = ms; + }else{ diff --git a/sqlite3/hooks.c b/sqlite3/hooks.c index ee9b735..c872131 100644 --- a/sqlite3/hooks.c +++ b/sqlite3/hooks.c @@ -4,6 +4,7 @@ int go_progress_handler(void *); int go_busy_handler(void *, int); +int go_busy_timeout(void *, int count, int tmout); int go_commit_hook(void *); void go_rollback_hook(void *); @@ -55,4 +56,12 @@ int sqlite3_autovacuum_pages_go(sqlite3 *db, go_handle app) { int rc = sqlite3_autovacuum_pages(db, go_autovacuum_pages, app, go_destroy); if (rc) go_destroy(app); return rc; -} \ No newline at end of file +} + +#ifndef sqliteBusyCallback + +static int sqliteBusyCallback(sqlite3 *db, int count) { + return go_busy_timeout(db, count, db->busyTimeout); +} + +#endif \ No newline at end of file diff --git a/sqlite3/sqlite_cfg.h b/sqlite3/sqlite_cfg.h index 524cc72..d86793a 100644 --- a/sqlite3/sqlite_cfg.h +++ b/sqlite3/sqlite_cfg.h @@ -35,8 +35,12 @@ // Because Wasm does not support shared memory, // SQLite disables WAL for Wasm builds. -// But we want it. #undef SQLITE_OMIT_WAL // Implemented in vfs.c. -int localtime_s(struct tm *const pTm, time_t const *const pTime); \ No newline at end of file +int localtime_s(struct tm *const pTm, time_t const *const pTime); + +// Implemented in hooks.c. +#ifndef sqliteBusyCallback +static int sqliteBusyCallback(sqlite3 *, int); +#endif \ No newline at end of file diff --git a/sqlite3/vfs_find.patch b/sqlite3/vfs_find.patch index 8435763..861c049 100644 --- a/sqlite3/vfs_find.patch +++ b/sqlite3/vfs_find.patch @@ -1,7 +1,8 @@ # Wrap sqlite3_vfs_find. +# This patch allows Go VFSes to be (un)registered. --- sqlite3.c.orig +++ sqlite3.c -@@ -26089,7 +26089,7 @@ +@@ -26372,7 +26372,7 @@ ** Locate a VFS by name. If no name is given, simply return the ** first VFS on the list. */ diff --git a/tests/txn_test.go b/tests/txn_test.go index 415fcca..15b9809 100644 --- a/tests/txn_test.go +++ b/tests/txn_test.go @@ -8,6 +8,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/tests/testcfg" + _ "github.com/ncruces/go-sqlite3/vfs/memdb" ) func TestConn_Transaction_exec(t *testing.T) { @@ -247,6 +248,51 @@ func TestConn_Transaction_interrupted(t *testing.T) { } } +func TestConn_Transaction_busy(t *testing.T) { + t.Parallel() + + db1, err := sqlite3.Open("file:/test.db?vfs=memdb") + if err != nil { + t.Fatal(err) + } + defer db1.Close() + + db2, err := sqlite3.Open("file:/test.db?vfs=memdb&_pragma=busy_timeout(10000)") + if err != nil { + t.Fatal(err) + } + defer db2.Close() + + err = db1.Exec(`CREATE TABLE test (col)`) + if err != nil { + t.Fatal(err) + } + + tx, err := db1.BeginImmediate() + if err != nil { + t.Fatal(err) + } + err = db1.Exec(`INSERT INTO test VALUES (1)`) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + db2.SetInterrupt(ctx) + go cancel() + + _, err = db2.BeginExclusive() + if !errors.Is(err, sqlite3.BUSY) { + t.Errorf("got %v, want sqlite3.BUSY", err) + } + + err = nil + tx.End(&err) + if err != nil { + t.Fatal(err) + } +} + func TestConn_Transaction_rollback(t *testing.T) { t.Parallel() diff --git a/vfs/tests/mptest/testdata/main.c b/vfs/tests/mptest/testdata/main.c index 4e9688e..ba02c75 100644 --- a/vfs/tests/mptest/testdata/main.c +++ b/vfs/tests/mptest/testdata/main.c @@ -1,5 +1,7 @@ #include +#define sqliteBusyCallback sqliteDefaultBusyCallback + // Amalgamation #include "sqlite3.c" // VFS diff --git a/vfs/tests/speedtest1/testdata/main.c b/vfs/tests/speedtest1/testdata/main.c index a203a3b..a5d6a5e 100644 --- a/vfs/tests/speedtest1/testdata/main.c +++ b/vfs/tests/speedtest1/testdata/main.c @@ -1,6 +1,8 @@ #include #include +#define sqliteBusyCallback sqliteDefaultBusyCallback + // Amalgamation #include "sqlite3.c" // VFS