Rework context cancellation. (#251)

This commit is contained in:
Nuno Cruces
2025-03-26 11:39:06 +00:00
committed by GitHub
parent befed7cf23
commit 948641194b
7 changed files with 90 additions and 64 deletions

View File

@@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
if c.interrupt.Err() != nil {
return nil, INTERRUPT
}
defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
@@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1
}
c.checkInterrupt()
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
@@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
b.c.checkInterrupt()
if b.c.interrupt.Err() != nil {
return INTERRUPT
}
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.offset = 0

View File

@@ -275,12 +275,14 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
//
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
if c.interrupt.Err() != nil {
return 0, 0, INTERRUPT
}
defer c.arena.mark()()
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
schemaPtr := c.arena.string(schema)
c.checkInterrupt()
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
stk_t(nLogPtr), stk_t(nCkptPtr)))

51
conn.go
View File

@@ -40,8 +40,6 @@ type Conn struct {
busylst time.Time
arena arena
handle ptr_t
pending ptr_t
stepped bool
gosched uint8
}
@@ -167,9 +165,6 @@ func (c *Conn) Close() error {
return nil
}
c.call("sqlite3_finalize", stk_t(c.pending))
c.pending = 0
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
return err
@@ -184,10 +179,15 @@ func (c *Conn) Close() error {
//
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
return c.exec(sql)
}
func (c *Conn) exec(sql string) error {
defer c.arena.mark()()
textPtr := c.arena.string(sql)
c.checkInterrupt()
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
return c.error(rc, sql)
}
@@ -207,13 +207,15 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_SQL_LENGTH {
return nil, "", TOOBIG
}
if c.interrupt.Err() != nil {
return nil, "", INTERRUPT
}
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
textPtr := c.arena.string(sql)
c.checkInterrupt()
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
@@ -343,42 +345,9 @@ func (c *Conn) GetInterrupt() context.Context {
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
old = c.interrupt
c.interrupt = ctx
if ctx == old {
return old
}
// An active SQL statement prevents SQLite from ignoring an interrupt
// that comes before any other statements are started.
if c.pending == 0 {
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
textPtr := c.arena.string(`SELECT 0 UNION ALL SELECT 0`)
c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(textPtr), math.MaxUint64,
stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0)
c.pending = util.Read32[ptr_t](c.mod, stmtPtr)
}
if c.stepped && ctx.Err() == nil {
c.call("sqlite3_reset", stk_t(c.pending))
c.stepped = false
} else {
c.checkInterrupt()
}
return old
}
func (c *Conn) checkInterrupt() {
if c.interrupt.Err() == nil {
return
}
if !c.stepped {
c.call("sqlite3_step", stk_t(c.pending))
c.stepped = true
}
c.call("sqlite3_interrupt", stk_t(c.handle))
}
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.gosched++; c.gosched%16 == 0 {

View File

@@ -199,6 +199,62 @@ func Test_BeginTx(t *testing.T) {
}
}
func Test_nested_context(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
outer, err := tx.Query(`SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer outer.Close()
want := func(rows *sql.Rows, want int) {
t.Helper()
var got int
rows.Next()
if err := rows.Scan(&got); err != nil {
t.Fatal(err)
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
want(outer, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer inner.Close()
want(inner, 0)
cancel()
if inner.Next() || !errors.Is(inner.Err(), sqlite3.INTERRUPT) {
t.Fatal(inner.Err())
}
want(outer, 1)
}
func Test_Prepare(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)

View File

@@ -106,7 +106,11 @@ func (s *Stmt) Busy() bool {
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
if s.c.interrupt.Err() != nil {
s.err = INTERRUPT
return false
}
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:

21
txn.go
View File

@@ -2,7 +2,6 @@ package sqlite3
import (
"context"
"errors"
"math/rand"
"runtime"
"strconv"
@@ -25,7 +24,7 @@ type Txn struct {
// https://sqlite.org/lang_transaction.html
func (c *Conn) Begin() Txn {
// BEGIN even if interrupted.
err := c.txnExecInterrupted(`BEGIN DEFERRED`)
err := c.exec(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
@@ -120,7 +119,7 @@ func (tx Txn) Commit() error {
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) Rollback() error {
return tx.c.txnExecInterrupted(`ROLLBACK`)
return tx.c.exec(`ROLLBACK`)
}
// Savepoint is a marker within a transaction
@@ -143,7 +142,7 @@ func (c *Conn) Savepoint() Savepoint {
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.txnExecInterrupted(`SAVEPOINT ` + name)
err := c.exec(`SAVEPOINT ` + name)
if err != nil {
panic(err)
}
@@ -199,7 +198,7 @@ func (s Savepoint) Release(errp *error) {
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
if err != nil {
panic(err)
}
@@ -212,17 +211,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name)
}
func (c *Conn) txnExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
return s.c.exec(`ROLLBACK TO ` + s.name)
}
// TxnState determines the transaction state of a database.

View File

@@ -79,10 +79,11 @@ func implements[T any](typ reflect.Type) bool {
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVTab(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
defer c.arena.mark()()
textPtr := c.arena.string(sql)
c.checkInterrupt()
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr)))
return c.error(rc)
}