From 7438fdb664a07335d2a32063d26bb95f767d1320 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 2 Feb 2024 23:41:34 +0000 Subject: [PATCH] Busy handlers. --- conn.go | 41 +++++++++++++++++++++++++++++++-- driver/driver.go | 2 +- embed/exports.txt | 2 ++ sqlite.go | 3 ++- sqlite3/hooks.c | 9 ++++++-- tests/parallel/parallel_test.go | 21 +++++++++-------- txn.go | 2 +- 7 files changed, 64 insertions(+), 16 deletions(-) diff --git a/conn.go b/conn.go index 3636336..6addcb4 100644 --- a/conn.go +++ b/conn.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "math" "net/url" "strings" + "time" "github.com/ncruces/go-sqlite3/internal/util" "github.com/tetratelabs/wazero/api" @@ -20,6 +22,7 @@ type Conn struct { interrupt context.Context pending *Stmt + busy func(int) bool log func(xErrorCode, string) collation func(*Conn, string) authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode @@ -322,8 +325,8 @@ func (c *Conn) checkInterrupt() { } } -func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 { - if c, ok := ctx.Value(connKey{}).(*Conn); ok { +func progressCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { if c.interrupt != nil && c.interrupt.Err() != nil { return 1 } @@ -331,6 +334,40 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 { return 0 } +// BusyTimeout sets a busy timeout. +// +// https://sqlite.org/c3ref/busy_timeout.html +func (c *Conn) BusyTimeout(timeout time.Duration) error { + ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32) + r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms)) + return c.error(r) +} + +// BusyHandler registers a callback to handle [BUSY] errors. +// +// https://sqlite.org/c3ref/busy_handler.html +func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error { + var enable uint64 + if cb != nil { + enable = 1 + } + r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable) + if err := c.error(r); err != nil { + return err + } + c.busy = cb + return nil +} + +func busyCallback(ctx context.Context, mod api.Module, pDB, count uint32) uint32 { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { + if retry := c.busy(int(count)); retry { + return 1 + } + } + return 0 +} + // Deprecated: executes a PRAGMA statement and returns results. func (c *Conn) Pragma(str string) ([]string, error) { stmt, _, err := c.Prepare(`PRAGMA ` + str) diff --git a/driver/driver.go b/driver/driver.go index ae55b00..19f5001 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -170,7 +170,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { defer c.Conn.SetInterrupt(old) if !n.pragmas { - err = c.Conn.Exec(`PRAGMA busy_timeout=60000`) + err = c.Conn.BusyTimeout(60 * time.Second) if err != nil { return nil, err } diff --git a/embed/exports.txt b/embed/exports.txt index 1ddb57c..73c680f 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -24,6 +24,8 @@ sqlite3_blob_open sqlite3_blob_read sqlite3_blob_reopen sqlite3_blob_write +sqlite3_busy_handler_go +sqlite3_busy_timeout sqlite3_changes64 sqlite3_clear_bindings sqlite3_close diff --git a/sqlite.go b/sqlite.go index 9069ef2..90dc8ed 100644 --- a/sqlite.go +++ b/sqlite.go @@ -289,7 +289,8 @@ func (a *arena) string(s string) uint32 { } func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { - util.ExportFuncII(env, "go_progress", progressCallback) + util.ExportFuncIII(env, "go_busy_handler", busyCallback) + util.ExportFuncII(env, "go_progress_handler", progressCallback) util.ExportFuncII(env, "go_commit_hook", commitCallback) util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback) diff --git a/sqlite3/hooks.c b/sqlite3/hooks.c index b44f9a8..445d661 100644 --- a/sqlite3/hooks.c +++ b/sqlite3/hooks.c @@ -2,7 +2,8 @@ #include "sqlite3.h" -int go_progress(void *); +int go_progress_handler(void *); +int go_busy_handler(void *, int); int go_commit_hook(void *); void go_rollback_hook(void *); @@ -14,7 +15,11 @@ int go_authorizer(void *, int, const char *, const char *, const char *, void go_log(void *, int, const char *); void sqlite3_progress_handler_go(sqlite3 *db, int n) { - sqlite3_progress_handler(db, n, go_progress, /*arg=*/db); + sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/db); +} + +int sqlite3_busy_handler_go(sqlite3 *db, bool enable) { + return sqlite3_busy_handler(db, enable ? go_busy_handler : NULL, /*arg=*/db); } void sqlite3_commit_hook_go(sqlite3 *db, bool enable) { diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index 7fc7f85..7c2f672 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -6,6 +6,7 @@ import ( "os/exec" "path/filepath" "testing" + "time" "golang.org/x/sync/errgroup" @@ -39,10 +40,7 @@ func TestMemory(t *testing.T) { iter = 5000 } - name := "file:/test.db?vfs=memdb" + - "&_pragma=busy_timeout(10000)" + - "&_pragma=journal_mode(memory)" + - "&_pragma=synchronous(off)" + name := "file:/test.db?vfs=memdb" testParallel(t, name, iter) testIntegrity(t, name) } @@ -100,10 +98,7 @@ func TestChildProcess(t *testing.T) { func BenchmarkMemory(b *testing.B) { memdb.Delete("test.db") - name := "file:/test.db?vfs=memdb" + - "&_pragma=busy_timeout(10000)" + - "&_pragma=journal_mode(memory)" + - "&_pragma=synchronous(off)" + name := "file:/test.db?vfs=memdb" testParallel(b, name, b.N) } @@ -115,6 +110,14 @@ func testParallel(t testing.TB, name string, n int) { } defer db.Close() + err = db.BusyHandler(func(count int) (retry bool) { + time.Sleep(time.Millisecond) + return true + }) + if err != nil { + return err + } + err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) if err != nil { return err @@ -135,7 +138,7 @@ func testParallel(t testing.TB, name string, n int) { } defer db.Close() - err = db.Exec(`PRAGMA busy_timeout=10000`) + err = db.BusyTimeout(10 * time.Second) if err != nil { return err } diff --git a/txn.go b/txn.go index 8475bfb..263c0d9 100644 --- a/txn.go +++ b/txn.go @@ -275,7 +275,7 @@ func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table str func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { - if !c.commit() { + if ok := c.commit(); !ok { return 1 } }