diff --git a/blob.go b/blob.go index 2fac720..bf3a275 100644 --- a/blob.go +++ b/blob.go @@ -42,7 +42,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, flags = 1 } - c.checkInterrupt(c.handle) + c.checkInterrupt() rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle), stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr), stk_t(row), stk_t(flags), stk_t(blobPtr))) @@ -253,7 +253,7 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { // // https://sqlite.org/c3ref/blob_reopen.html func (b *Blob) Reopen(row int64) error { - b.c.checkInterrupt(b.c.handle) + b.c.checkInterrupt() err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row)))) b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle)))) b.offset = 0 diff --git a/conn.go b/conn.go index a7f4acd..2ea16a2 100644 --- a/conn.go +++ b/conn.go @@ -25,7 +25,6 @@ type Conn struct { *sqlite interrupt context.Context - pending *Stmt stmts []*Stmt busy func(context.Context, int) bool log func(xErrorCode, string) @@ -41,7 +40,9 @@ type Conn struct { busylst time.Time arena arena handle ptr_t - nprogr uint8 + pending ptr_t + stepped bool + gosched uint8 } // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. @@ -133,7 +134,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { } } if pragmas.Len() != 0 { - c.checkInterrupt(handle) pragmaPtr := c.arena.string(pragmas.String()) rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0)) if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil { @@ -167,8 +167,8 @@ func (c *Conn) Close() error { return nil } - c.pending.Close() - c.pending = nil + c.call("sqlite3_finalize", stk_t(c.pending)) + c.pending = 0 rc := res_t(c.call("sqlite3_close", stk_t(c.handle))) if err := c.error(rc); err != nil { @@ -187,7 +187,7 @@ func (c *Conn) Exec(sql string) error { defer c.arena.mark()() sqlPtr := c.arena.string(sql) - c.checkInterrupt(c.handle) + c.checkInterrupt() rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0)) return c.error(rc, sql) } @@ -211,16 +211,16 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) tailPtr := c.arena.new(ptrlen) - sqlPtr := c.arena.string(sql) + textPtr := c.arena.string(sql) - c.checkInterrupt(c.handle) + c.checkInterrupt() rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle), - stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags), + stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags), stk_t(stmtPtr), stk_t(tailPtr))) - stmt = &Stmt{c: c} + stmt = &Stmt{c: c, sql: sql} stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr) - if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" { + if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" { tail = sql } @@ -344,40 +344,44 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { old = c.interrupt c.interrupt = ctx - if ctx == old || ctx.Done() == old.Done() { + if ctx == old { return old } - // A busy SQL statement prevents SQLite from ignoring an interrupt + // An active SQL statement prevents SQLite from ignoring an interrupt // that comes before any other statements are started. - if c.pending == nil { + if c.pending == 0 { defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) - loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) - c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64, + textPtr := c.arena.string(`SELECT 0 UNION ALL SELECT 0`) + c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(textPtr), math.MaxUint64, stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0) - c.pending = &Stmt{c: c} - c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr) + c.pending = util.Read32[ptr_t](c.mod, stmtPtr) } - if old.Done() != nil && ctx.Err() == nil { - c.pending.Reset() - } - if ctx.Done() != nil { - c.pending.Step() + if c.stepped && ctx.Err() == nil { + c.call("sqlite3_reset", stk_t(c.pending)) + c.stepped = false + } else { + c.checkInterrupt() } return old } -func (c *Conn) checkInterrupt(handle ptr_t) { - if c.interrupt.Err() != nil { - c.call("sqlite3_interrupt", stk_t(handle)) +func (c *Conn) checkInterrupt() { + if c.interrupt.Err() == nil { + return } + if !c.stepped { + c.call("sqlite3_step", stk_t(c.pending)) + c.stepped = true + } + c.call("sqlite3_interrupt", stk_t(c.handle)) } func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok { - if c.nprogr++; c.nprogr%16 == 0 { + if c.gosched++; c.gosched%16 == 0 { runtime.Gosched() } if c.interrupt.Err() != nil { diff --git a/func.go b/func.go index 934a5c1..16b4305 100644 --- a/func.go +++ b/func.go @@ -284,10 +284,10 @@ func returnArgs(p *[]Value) { } type aggregateFunc struct { - ctx Context - arg []Value next func() (struct{}, bool) stop func() + ctx Context + arg []Value } func (a *aggregateFunc) Step(ctx Context, arg ...Value) { diff --git a/stmt.go b/stmt.go index 2581eeb..5314595 100644 --- a/stmt.go +++ b/stmt.go @@ -106,7 +106,7 @@ func (s *Stmt) Busy() bool { // // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { - s.c.checkInterrupt(s.c.handle) + s.c.checkInterrupt() rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle))) switch rc { case _ROW: