diff --git a/conn.go b/conn.go index be0c9aa..01cc1a3 100644 --- a/conn.go +++ b/conn.go @@ -332,6 +332,6 @@ type DriverConn interface { driver.ExecerContext driver.ConnPrepareContext - Savepoint() (release func(*error)) + Savepoint() Savepoint OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) } diff --git a/driver/driver.go b/driver/driver.go index 0c56212..b97e30c 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -189,7 +189,7 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named }, nil } -func (c conn) Savepoint() (release func(*error)) { +func (c conn) Savepoint() sqlite3.Savepoint { return c.conn.Savepoint() } diff --git a/driver_test.go b/driver_test.go index 6e9c931..3b5bad5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -48,7 +48,8 @@ func ExampleDriverConn() { err = conn.Raw(func(driverConn any) error { conn := driverConn.(sqlite3.DriverConn) - defer conn.Savepoint()(&err) + savept := conn.Savepoint() + defer savept.Release(&err) blob, err := conn.OpenBlob("main", "test", "col", id, true) if err != nil { diff --git a/tests/tx_test.go b/tests/tx_test.go index bb63531..45470e5 100644 --- a/tests/tx_test.go +++ b/tests/tx_test.go @@ -185,10 +185,10 @@ func TestConn_Transaction_interrupt(t *testing.T) { t.Errorf("got %v, want sqlite3.INTERRUPT", err) } - var nilErr error - tx.End(&nilErr) - if !errors.Is(nilErr, sqlite3.INTERRUPT) { - t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr) + err = nil + tx.End(&err) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) } db.SetInterrupt(context.Background()) @@ -210,6 +210,33 @@ func TestConn_Transaction_interrupt(t *testing.T) { } } +func TestConn_Transaction_interrupted(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + db.SetInterrupt(ctx) + cancel() + + tx := db.Begin() + + err = tx.Commit() + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } + + err = nil + tx.End(&err) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } +} + func TestConn_Transaction_rollback(t *testing.T) { t.Parallel() @@ -286,7 +313,7 @@ func TestConn_Savepoint_exec(t *testing.T) { } insert := func(succeed bool) (err error) { - defer db.Savepoint()(&err) + defer db.Savepoint().Release(&err) err = db.Exec(`INSERT INTO test VALUES ('hello')`) if err != nil { @@ -344,7 +371,7 @@ func TestConn_Savepoint_panic(t *testing.T) { } panics := func() (err error) { - defer db.Savepoint()(&err) + defer db.Savepoint().Release(&err) err = db.Exec(`INSERT INTO test VALUES ('hello')`) if err != nil { @@ -395,12 +422,12 @@ func TestConn_Savepoint_interrupt(t *testing.T) { t.Fatal(err) } - release := db.Savepoint() + savept := db.Savepoint() err = db.Exec(`INSERT INTO test VALUES (1)`) if err != nil { t.Fatal(err) } - release(&err) + savept.Release(&err) if err != nil { t.Fatal(err) } @@ -408,19 +435,19 @@ func TestConn_Savepoint_interrupt(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) db.SetInterrupt(ctx) - release1 := db.Savepoint() + savept1 := db.Savepoint() err = db.Exec(`INSERT INTO test VALUES (2)`) if err != nil { t.Fatal(err) } - release2 := db.Savepoint() + savept2 := db.Savepoint() err = db.Exec(`INSERT INTO test VALUES (3)`) if err != nil { t.Fatal(err) } cancel() - db.Savepoint()(&err) + db.Savepoint().Release(&err) if !errors.Is(err, sqlite3.INTERRUPT) { t.Errorf("got %v, want sqlite3.INTERRUPT", err) } @@ -431,15 +458,15 @@ func TestConn_Savepoint_interrupt(t *testing.T) { } err = context.Canceled - release2(&err) + savept2.Release(&err) if err != context.Canceled { t.Fatal(err) } - var nilErr error - release1(&nilErr) - if !errors.Is(nilErr, sqlite3.INTERRUPT) { - t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr) + err = nil + savept1.Release(&err) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) } db.SetInterrupt(context.Background()) @@ -475,7 +502,7 @@ func TestConn_Savepoint_rollback(t *testing.T) { t.Fatal(err) } - release := db.Savepoint() + savept := db.Savepoint() err = db.Exec(`INSERT INTO test VALUES (1)`) if err != nil { t.Fatal(err) @@ -484,7 +511,7 @@ func TestConn_Savepoint_rollback(t *testing.T) { if err != nil { t.Fatal(err) } - release(&err) + savept.Release(&err) if err != nil { t.Fatal(err) } diff --git a/tx.go b/tx.go index 71300cf..0eeca6d 100644 --- a/tx.go +++ b/tx.go @@ -4,10 +4,14 @@ import ( "context" "errors" "fmt" + "math/rand" "runtime" + "strconv" ) // Tx is an in-progress database transaction. +// +// https://www.sqlite.org/lang_transaction.html type Tx struct { c *Conn } @@ -16,8 +20,9 @@ type Tx struct { // // https://www.sqlite.org/lang_transaction.html func (c *Conn) Begin() Tx { - err := c.Exec(`BEGIN DEFERRED`) - if err != nil && !errors.Is(err, INTERRUPT) { + // BEGIN even if interrupted. + err := c.txExecInterrupted(`BEGIN DEFERRED`) + if err != nil { panic(err) } return Tx{c} @@ -64,21 +69,22 @@ func (tx Tx) End(errp *error) { defer panic(recovered) } - if tx.c.GetAutocommit() { - // There is nothing to commit/rollback. - return - } - - if *errp == nil && recovered == nil { + if (errp == nil || *errp == nil) && recovered == nil { // Success path. + if tx.c.GetAutocommit() { // There is nothing to commit. + return + } *errp = tx.Commit() if *errp == nil { return } - // Possible interrupt, fall through to the error path. + // Fall through to the error path. } // Error path. + if tx.c.GetAutocommit() { // There is nothing to rollback. + return + } err := tx.Rollback() if err != nil { panic(err) @@ -92,33 +98,28 @@ func (tx Tx) Commit() error { return tx.c.Exec(`COMMIT`) } -// Rollback rolls back the transaction. +// Rollback rolls back the transaction, +// even if the connection has been interrupted. // // https://www.sqlite.org/lang_transaction.html func (tx Tx) Rollback() error { - // ROLLBACK even if the connection has been interrupted. - old := tx.c.SetInterrupt(context.Background()) - defer tx.c.SetInterrupt(old) - return tx.c.Exec(`ROLLBACK`) + return tx.c.txExecInterrupted(`ROLLBACK`) } -// Savepoint creates a named SQLite transaction using SAVEPOINT. -// -// On success Savepoint returns a release func that will call either -// RELEASE or ROLLBACK depending on whether the parameter *error -// points to a nil or non-nil error. -// -// This is meant to be deferred: -// -// func doWork(conn *sqlite3.Conn) (err error) { -// defer conn.Savepoint()(&err) -// -// // ... do work in the transaction -// } +// Savepoint is a marker within a transaction +// that allows for partial rollback. // // https://www.sqlite.org/lang_savepoint.html -func (c *Conn) Savepoint() (release func(*error)) { - name := "sqlite3.Savepoint" // names can be reused +type Savepoint struct { + c *Conn + name string +} + +// Savepoint establishes a new transaction savepoint. +// +// https://www.sqlite.org/lang_savepoint.html +func (c *Conn) Savepoint() Savepoint { + name := "sqlite3.Savepoint" var pc [1]uintptr if n := runtime.Callers(2, pc[:]); n > 0 { frames := runtime.CallersFrames(pc[:n]) @@ -127,52 +128,75 @@ func (c *Conn) Savepoint() (release func(*error)) { name = frame.Function } } + // Names can be reused; this makes catching bugs more likely. + name += "#" + strconv.Itoa(int(rand.Int31())) - err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) + err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name)) if err != nil { - if errors.Is(err, INTERRUPT) { - return func(errp *error) { - if *errp == nil { - *errp = err - } - } - } panic(err) } + return Savepoint{c: c, name: name} +} - return func(errp *error) { - recovered := recover() - if recovered != nil { - defer panic(recovered) - } +// Release releases the savepoint rolling back any changes +// if *error points to a non-nil error. +// +// This is meant to be deferred: +// +// func doWork(conn *sqlite3.Conn) (err error) { +// savept := conn.Savepoint() +// defer savept.Release(&err) +// +// // ... do work in the transaction +// } +func (s Savepoint) Release(errp *error) { + recovered := recover() + if recovered != nil { + defer panic(recovered) + } - if c.GetAutocommit() { - // There is nothing to commit/rollback. + if (errp == nil || *errp == nil) && recovered == nil { + // Success path. + if s.c.GetAutocommit() { // There is nothing to commit. return } - - if *errp == nil && recovered == nil { - // Success path. - // RELEASE the savepoint successfully. - *errp = c.Exec(fmt.Sprintf("RELEASE %q;", name)) - if *errp == nil { - return - } - // Possible interrupt, fall through to the error path. + *errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name)) + if *errp == nil { + return } + // Fall through to the error path. + } - // Error path. - // Always ROLLBACK even if the connection has been interrupted. - old := c.SetInterrupt(context.Background()) - defer c.SetInterrupt(old) - - err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name)) - if err != nil { - panic(err) - } - err = c.Exec(fmt.Sprintf("RELEASE %q;", name)) - if err != nil { - panic(err) - } + // Error path. + if s.c.GetAutocommit() { // There is nothing to rollback. + return + } + // ROLLBACK and RELEASE even if interrupted. + err := s.c.txExecInterrupted(fmt.Sprintf(` + ROLLBACK TO %[1]q; + RELEASE %[1]q; + `, s.name)) + if err != nil { + panic(err) } } + +// Rollback rolls the transaction back to the savepoint, +// even if the connection has been interrupted. +// Rollback does not release the savepoint. +// +// https://www.sqlite.org/lang_transaction.html +func (s Savepoint) Rollback() error { + // ROLLBACK even if interrupted. + return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name)) +} + +func (c *Conn) txExecInterrupted(sql string) error { + err := c.Exec(sql) + if errors.Is(err, INTERRUPT) { + old := c.SetInterrupt(context.Background()) + defer c.SetInterrupt(old) + err = c.Exec(sql) + } + return err +}