Improved interrupts.

This commit is contained in:
Nuno Cruces
2024-10-04 13:31:53 +01:00
parent ac94a5406e
commit 96c61a2f55
10 changed files with 69 additions and 41 deletions

View File

@@ -31,7 +31,6 @@ var _ io.ReadWriteSeeker = &Blob{}
// //
// https://sqlite.org/c3ref/blob_open.html // https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.mark()() defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen) blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db) dbPtr := c.arena.string(db)
@@ -43,6 +42,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1 flags = 1
} }
c.checkInterrupt(c.handle)
r := c.call("sqlite3_blob_open", uint64(c.handle), r := c.call("sqlite3_blob_open", uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr)) uint64(row), flags, uint64(blobPtr))

76
conn.go
View File

@@ -40,12 +40,18 @@ type Conn struct {
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (*Conn, error) { func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) return newConn(context.Background(), filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connectiton.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
} }
// OpenFlags opens an SQLite database file as specified by the filename argument. // OpenFlags opens an SQLite database file as specified by the filename argument.
// //
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used. // If none of the required flags are used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma": // If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
// //
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)") // sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
@@ -55,33 +61,37 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 { if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE flags |= OPEN_READWRITE | OPEN_CREATE
} }
return newConn(filename, flags) return newConn(context.Background(), filename, flags)
} }
type connKey struct{} type connKey struct{}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { func newConn(ctx context.Context, filename string, flags OpenFlag) (conn *Conn, err error) {
err = ctx.Err()
if err != nil {
return nil, err
}
sqlite, err := instantiateSQLite() sqlite, err := instantiateSQLite()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
if conn == nil { if err != nil {
conn.Close()
conn = nil
sqlite.close() sqlite.close()
} }
}() }()
c := &Conn{sqlite: sqlite} c := &Conn{sqlite: sqlite, interrupt: ctx}
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c) c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags) c.handle, err = c.openDB(filename, flags)
if err == nil { if err == nil {
err = initExtensions(c) err = initExtensions(c)
} }
if err != nil { c.interrupt = context.Background()
return nil, err return c, err
}
return c, nil
} }
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
@@ -98,6 +108,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
return 0, err return 0, err
} }
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok { if _, after, ok := strings.Cut(filename, "?"); ok {
@@ -109,6 +120,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
} }
} }
if pragmas.Len() != 0 { if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String()) pragmaPtr := c.arena.string(pragmas.String())
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0) r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
@@ -118,7 +130,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
} }
} }
} }
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
return handle, nil return handle, nil
} }
@@ -160,10 +171,10 @@ func (c *Conn) Close() error {
// //
// https://sqlite.org/c3ref/exec.html // https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error { func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.mark()() defer c.arena.mark()()
sqlPtr := c.arena.string(sql) sqlPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0) r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r, sql) return c.error(r, sql)
} }
@@ -301,8 +312,7 @@ func (c *Conn) ReleaseMemory() error {
return c.error(r) return c.error(r)
} }
// GetInterrupt gets the context set with [Conn.SetInterrupt], // GetInterrupt gets the context set with [Conn.SetInterrupt].
// or nil if none was set.
func (c *Conn) GetInterrupt() context.Context { func (c *Conn) GetInterrupt() context.Context {
return c.interrupt return c.interrupt
} }
@@ -322,9 +332,11 @@ func (c *Conn) GetInterrupt() context.Context {
// //
// https://sqlite.org/c3ref/interrupt.html // https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is it the same context? old = c.interrupt
if ctx == c.interrupt { c.interrupt = ctx
return ctx
if ctx == old || ctx.Done() == old.Done() {
return old
} }
// A busy SQL statement prevents SQLite from ignoring an interrupt // A busy SQL statement prevents SQLite from ignoring an interrupt
@@ -333,32 +345,29 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
defer c.arena.mark()() defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen) stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, 0, uint64(stmtPtr), 0) c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64,
uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0)
c.pending = &Stmt{c: c} c.pending = &Stmt{c: c}
c.pending.handle = util.ReadUint32(c.mod, stmtPtr) c.pending.handle = util.ReadUint32(c.mod, stmtPtr)
} }
old = c.interrupt if old.Done() != nil && ctx.Err() == nil {
c.interrupt = ctx
if old != nil && old.Done() != nil && (ctx == nil || ctx.Err() == nil) {
c.pending.Reset() c.pending.Reset()
} }
if ctx != nil && ctx.Done() != nil { if ctx.Done() != nil {
c.pending.Step() c.pending.Step()
} }
return old return old
} }
func (c *Conn) checkInterrupt() { func (c *Conn) checkInterrupt(handle uint32) {
if c.interrupt != nil && c.interrupt.Err() != nil { if c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", uint64(c.handle)) c.call("sqlite3_interrupt", uint64(handle))
} }
} }
func progressCallback(ctx context.Context, mod api.Module, pDB uint32) (interrupt uint32) { func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() != nil {
c.interrupt != nil && c.interrupt.Err() != nil {
interrupt = 1 interrupt = 1
} }
return interrupt return interrupt
@@ -373,9 +382,8 @@ func (c *Conn) BusyTimeout(timeout time.Duration) error {
return c.error(r) return c.error(r)
} }
func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmout int32) (retry uint32) { func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil {
(c.interrupt == nil || c.interrupt.Err() == nil) {
const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64" const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64"
const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4" const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4"
const ndelay = int32(len(delays) - 1) const ndelay = int32(len(delays) - 1)
@@ -391,7 +399,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmo
if delay = min(delay, tmout-prior); delay > 0 { if delay = min(delay, tmout-prior); delay > 0 {
delay := time.Duration(delay) * time.Millisecond delay := time.Duration(delay) * time.Millisecond
if c.interrupt == nil || c.interrupt.Done() == nil { if c.interrupt.Done() == nil {
time.Sleep(delay) time.Sleep(delay)
return 1 return 1
} }

View File

@@ -209,7 +209,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
tmWrite: n.tmWrite, tmWrite: n.tmWrite,
} }
c.Conn, err = sqlite3.Open(n.name) c.Conn, err = sqlite3.OpenContext(ctx, n.name)
if err != nil { if err != nil {
return nil, err return nil, err
} }

Binary file not shown.

View File

@@ -51,6 +51,7 @@ sqlite3_create_collation_go
sqlite3_create_function_go sqlite3_create_function_go
sqlite3_create_module_go sqlite3_create_module_go
sqlite3_create_window_function_go sqlite3_create_window_function_go
sqlite3_data_count
sqlite3_database_file_object sqlite3_database_file_object
sqlite3_db_cacheflush sqlite3_db_cacheflush
sqlite3_db_config sqlite3_db_config

Binary file not shown.

View File

@@ -301,7 +301,7 @@ func (a *arena) string(s string) uint32 {
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress_handler", progressCallback) util.ExportFuncII(env, "go_progress_handler", progressCallback)
util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback) util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback)
util.ExportFuncIII(env, "go_busy_handler", busyCallback) util.ExportFuncIII(env, "go_busy_handler", busyCallback)
util.ExportFuncII(env, "go_commit_hook", commitCallback) util.ExportFuncII(env, "go_commit_hook", commitCallback)
util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)

View File

@@ -4,7 +4,7 @@
int go_progress_handler(void *); int go_progress_handler(void *);
int go_busy_handler(void *, int); int go_busy_handler(void *, int);
int go_busy_timeout(void *, int count, int tmout); int go_busy_timeout(int count, int tmout);
int go_commit_hook(void *); int go_commit_hook(void *);
void go_rollback_hook(void *); void go_rollback_hook(void *);
@@ -20,7 +20,7 @@ unsigned int go_autovacuum_pages(void *, const char *, unsigned int,
unsigned int, unsigned int); unsigned int, unsigned int);
void sqlite3_progress_handler_go(sqlite3 *db, int n) { void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/db); sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/NULL);
} }
int sqlite3_busy_handler_go(sqlite3 *db, bool enable) { int sqlite3_busy_handler_go(sqlite3 *db, bool enable) {
@@ -66,7 +66,7 @@ int sqlite3_autovacuum_pages_go(sqlite3 *db, go_handle app) {
#ifndef sqliteBusyCallback #ifndef sqliteBusyCallback
static int sqliteBusyCallback(sqlite3 *db, int count) { static int sqliteBusyCallback(sqlite3 *db, int count) {
return go_busy_timeout(db, count, db->busyTimeout); return go_busy_timeout(count, db->busyTimeout);
} }
#endif #endif

11
stmt.go
View File

@@ -106,7 +106,7 @@ func (s *Stmt) Busy() bool {
// //
// https://sqlite.org/c3ref/step.html // https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool { func (s *Stmt) Step() bool {
s.c.checkInterrupt() s.c.checkInterrupt(s.c.handle)
r := s.c.call("sqlite3_step", uint64(s.handle)) r := s.c.call("sqlite3_step", uint64(s.handle))
switch r { switch r {
case _ROW: case _ROW:
@@ -377,6 +377,15 @@ func (s *Stmt) BindValue(param int, value Value) error {
return s.c.error(r) return s.c.error(r)
} }
// DataCount resets the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/data_count.html
func (s *Stmt) DataCount() int {
r := s.c.call("sqlite3_data_count",
uint64(s.handle))
return int(int32(r))
}
// ColumnCount returns the number of columns in a result set. // ColumnCount returns the number of columns in a result set.
// //
// https://sqlite.org/c3ref/column_count.html // https://sqlite.org/c3ref/column_count.html

View File

@@ -146,6 +146,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ReadOnly(); got != true { if got := stmt.ReadOnly(); got != true {
t.Error("got false, want true") t.Error("got false, want true")
} }
if got := stmt.DataCount(); got != 0 {
t.Errorf("got %d, want 0", got)
}
if got := stmt.ColumnCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
if got := stmt.ColumnName(0); got != "c" { if got := stmt.ColumnName(0); got != "c" {
t.Errorf(`got %q, want "c"`, got) t.Errorf(`got %q, want "c"`, got)
} }
@@ -503,6 +509,10 @@ func TestStmt(t *testing.T) {
} }
} }
if got := stmt.DataCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
db.Stmts()(func(s *sqlite3.Stmt) bool { db.Stmts()(func(s *sqlite3.Stmt) bool {
if s != stmt { if s != stmt {
t.Error() t.Error()