From e5c285b783f1e8e68f049a5ad5c4d18431dd705b Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 28 Mar 2025 11:10:51 +0000 Subject: [PATCH] Discussion #250. --- conn.go | 10 +++++----- driver/driver.go | 11 ++++------- tests/conn_test.go | 5 +---- txn.go | 3 +++ 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/conn.go b/conn.go index 10d3cd2..7e88d8c 100644 --- a/conn.go +++ b/conn.go @@ -343,6 +343,9 @@ func (c *Conn) GetInterrupt() context.Context { // // https://sqlite.org/c3ref/interrupt.html func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { + if ctx == nil { + panic("nil Context") + } old = c.interrupt c.interrupt = ctx return old @@ -406,11 +409,8 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { - interrupt := c.interrupt - if interrupt == nil { - interrupt = context.Background() - } - if interrupt.Err() == nil && c.busy(interrupt, int(count)) { + if interrupt := c.interrupt; interrupt.Err() == nil && + c.busy(interrupt, int(count)) { retry = 1 } } diff --git a/driver/driver.go b/driver/driver.go index 21799ae..871aa74 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -358,13 +358,10 @@ func (c *conn) Commit() error { } func (c *conn) Rollback() error { - err := c.Conn.Exec(`ROLLBACK` + c.txReset) - if errors.Is(err, sqlite3.INTERRUPT) { - old := c.Conn.SetInterrupt(context.Background()) - defer c.Conn.SetInterrupt(old) - err = c.Conn.Exec(`ROLLBACK` + c.txReset) - } - return err + // ROLLBACK even if interrupted. + old := c.Conn.SetInterrupt(context.Background()) + defer c.Conn.SetInterrupt(old) + return c.Conn.Exec(`ROLLBACK` + c.txReset) } func (c *conn) Prepare(query string) (driver.Stmt, error) { diff --git a/tests/conn_test.go b/tests/conn_test.go index 2c017c0..0ab2405 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -146,10 +146,7 @@ func TestConn_SetInterrupt(t *testing.T) { } defer stmt.Close() - go func() { - time.Sleep(time.Millisecond) - cancel() - }() + time.AfterFunc(time.Millisecond, cancel) // Interrupting works. err = stmt.Exec() diff --git a/txn.go b/txn.go index a21b99a..7a5e112 100644 --- a/txn.go +++ b/txn.go @@ -20,6 +20,8 @@ type Txn struct { } // Begin starts a deferred transaction. +// Panics if a transaction is already in-progress. +// For nested transactions, use [Conn.Savepoint]. // // https://sqlite.org/lang_transaction.html func (c *Conn) Begin() Txn { @@ -119,6 +121,7 @@ func (tx Txn) Commit() error { // // https://sqlite.org/lang_transaction.html func (tx Txn) Rollback() error { + // ROLLBACK even if interrupted. return tx.c.exec(`ROLLBACK`) }