From 6509e5deb26ce3c321b398caab6c566e875d16fd Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 26 Feb 2023 03:22:08 +0000 Subject: [PATCH] Transactions. --- README.md | 3 +- conn.go | 76 --------- tests/{save_test.go => tx_test.go} | 240 ++++++++++++++++++++++++++++- tx.go | 171 ++++++++++++++++++++ 4 files changed, 412 insertions(+), 78 deletions(-) rename tests/{save_test.go => tx_test.go} (51%) create mode 100644 tx.go diff --git a/README.md b/README.md index 7258e7b..1939b36 100644 --- a/README.md +++ b/README.md @@ -58,4 +58,5 @@ and WAL databases are not supported. - [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite) - [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite) -- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3) \ No newline at end of file +- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3) +- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite) \ No newline at end of file diff --git a/conn.go b/conn.go index 9987c15..9a620c9 100644 --- a/conn.go +++ b/conn.go @@ -2,10 +2,7 @@ package sqlite3 import ( "context" - "errors" - "fmt" "math" - "runtime" "sync" "github.com/tetratelabs/wazero/api" @@ -267,79 +264,6 @@ func (c *Conn) sendInterrupt() { c.call(c.api.interrupt, uint64(c.handle)) } -// Savepoint creates a named SQLite transaction using SAVEPOINT. -// -// On success Savepoint returns a release func that will call -// either RELEASE or ROLLBACK depending on whether the parameter *error -// points to a nil or non-nil error. -// -// This is meant to be deferred: -// -// func doWork(conn *sqlite3.Conn) (err error) { -// defer conn.Savepoint()(&err) -// -// // ... do work in the transaction -// } -func (conn *Conn) Savepoint() (release func(*error)) { - name := "sqlite3.Savepoint" // names can be reused - var pc [1]uintptr - if n := runtime.Callers(2, pc[:]); n > 0 { - frames := runtime.CallersFrames(pc[:n]) - frame, _ := frames.Next() - if frame.Function != "" { - name = frame.Function - } - } - - err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) - if err != nil { - if errors.Is(err, INTERRUPT) { - return func(errp *error) { - if *errp == nil { - *errp = err - } - } - } - panic(err) - } - - return func(errp *error) { - recovered := recover() - if recovered != nil { - defer panic(recovered) - } - - if conn.GetAutocommit() { - // There is nothing to commit/rollback. - return - } - - if *errp == nil && recovered == nil { - // Success path. - // RELEASE the savepoint successfully. - *errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) - if *errp == nil { - return - } - // Possible interrupt, fall through to the error path. - } - - // Error path. - // Always ROLLBACK even if the connection has been interrupted. - old := conn.SetInterrupt(context.Background()) - defer conn.SetInterrupt(old) - - err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name)) - if err != nil { - panic(err) - } - err = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) - if err != nil { - panic(err) - } - } -} - // Pragma executes a PRAGMA statement and returns any result as a string. // // https://www.sqlite.org/pragma.html diff --git a/tests/save_test.go b/tests/tx_test.go similarity index 51% rename from tests/save_test.go rename to tests/tx_test.go index 23b88f2..fed6970 100644 --- a/tests/save_test.go +++ b/tests/tx_test.go @@ -8,6 +8,244 @@ import ( "github.com/ncruces/go-sqlite3" ) +func TestConn_Transaction_exec(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + errFailed := errors.New("failed") + + count := func() int { + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + return stmt.ColumnInt(0) + } + t.Fatal(stmt.Err()) + return 0 + } + + insert := func(succeed bool) (err error) { + tx := db.Begin() + defer tx.End(&err) + + err = db.Exec(`INSERT INTO test VALUES ('hello')`) + if err != nil { + t.Fatal(err) + } + + if succeed { + return nil + } + return errFailed + } + + err = insert(true) + if err != nil { + t.Fatal(err) + } + if got := count(); got != 1 { + t.Errorf("got %d, want 1", got) + } + + err = insert(true) + if err != nil { + t.Fatal(err) + } + if got := count(); got != 2 { + t.Errorf("got %d, want 2", got) + } + + err = insert(false) + if err != errFailed { + t.Errorf("got %v, want errFailed", err) + } + if got := count(); got != 2 { + t.Errorf("got %d, want 2", got) + } +} + +func TestConn_Transaction_panic(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO test VALUES ('one');`) + if err != nil { + t.Fatal(err) + } + + panics := func() (err error) { + tx := db.Begin() + defer tx.End(&err) + + err = db.Exec(`INSERT INTO test VALUES ('hello')`) + if err != nil { + return err + } + + panic("omg!") + } + + defer func() { + p := recover() + if p != "omg!" { + t.Errorf("got %v, want panic", p) + } + + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + got := stmt.ColumnInt(0) + if got != 1 { + t.Errorf("got %d, want 1", got) + } + return + } + t.Fatal(stmt.Err()) + }() + + err = panics() + if err != nil { + t.Error(err) + } +} + +func TestConn_Transaction_interrupt(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + tx, err := db.BeginImmediate() + if err != nil { + t.Fatal(err) + } + err = db.Exec(`INSERT INTO test(col) VALUES(1)`) + if err != nil { + t.Fatal(err) + } + tx.End(&err) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + db.SetInterrupt(ctx) + + tx, err = db.BeginExclusive() + if err != nil { + t.Fatal(err) + } + err = db.Exec(`INSERT INTO test(col) VALUES(2)`) + if err != nil { + t.Fatal(err) + } + + cancel() + _, err = db.BeginImmediate() + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } + + err = db.Exec(`INSERT INTO test(col) VALUES(3)`) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } + + var nilErr error + tx.End(&nilErr) + if !errors.Is(nilErr, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr) + } + + db.SetInterrupt(context.Background()) + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + got := stmt.ColumnInt(0) + if got != 1 { + t.Errorf("got %d, want 1", got) + } + } + err = stmt.Err() + if err != nil { + t.Error(err) + } +} + +func TestConn_Transaction_rollback(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + tx := db.Begin() + err = db.Exec(`INSERT INTO test(col) VALUES(1)`) + if err != nil { + t.Fatal(err) + } + err = db.Exec(`COMMIT`) + if err != nil { + t.Fatal(err) + } + tx.End(&err) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + got := stmt.ColumnInt(0) + if got != 1 { + t.Errorf("got %d, want 1", got) + } + } + err = stmt.Err() + if err != nil { + t.Error(err) + } +} + func TestConn_Savepoint_exec(t *testing.T) { db, err := sqlite3.Open(":memory:") if err != nil { @@ -183,7 +421,7 @@ func TestConn_Savepoint_interrupt(t *testing.T) { var nilErr error release1(&nilErr) if !errors.Is(nilErr, sqlite3.INTERRUPT) { - t.Errorf("got %v, want sqlite3.INTERRUPT", err) + t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr) } db.SetInterrupt(context.Background()) diff --git a/tx.go b/tx.go new file mode 100644 index 0000000..aeb2887 --- /dev/null +++ b/tx.go @@ -0,0 +1,171 @@ +package sqlite3 + +import ( + "context" + "errors" + "fmt" + "runtime" +) + +type Tx struct { + c *Conn +} + +// Begin starts a deferred transaction. +// +// https://www.sqlite.org/lang_transaction.html +func (c *Conn) Begin() Tx { + err := c.Exec(`BEGIN DEFERRED`) + if err != nil && !errors.Is(err, INTERRUPT) { + panic(err) + } + return Tx{c} +} + +// BeginImmediate starts an immediate transaction. +// +// https://www.sqlite.org/lang_transaction.html +func (c *Conn) BeginImmediate() (Tx, error) { + err := c.Exec(`BEGIN IMMEDIATE`) + if err != nil { + return Tx{}, err + } + return Tx{c}, nil +} + +// BeginExclusive starts an exclusive transaction. +// +// https://www.sqlite.org/lang_transaction.html +func (c *Conn) BeginExclusive() (Tx, error) { + err := c.Exec(`BEGIN EXCLUSIVE`) + if err != nil { + return Tx{}, err + } + return Tx{c}, nil +} + +// End calls either [Commit] or [Rollback] +// depending on whether *error points to a nil or non-nil error. +// +// This is meant to be deferred: +// +// func doWork(conn *sqlite3.Conn) (err error) { +// tx := conn.Begin() +// defer tx.End(&err) +// +// // ... do work in the transaction +// } +// +// https://www.sqlite.org/lang_savepoint.html +func (tx Tx) End(errp *error) { + recovered := recover() + if recovered != nil { + defer panic(recovered) + } + + if tx.c.GetAutocommit() { + // There is nothing to commit/rollback. + return + } + + if *errp == nil && recovered == nil { + // Success path. + *errp = tx.Commit() + if *errp == nil { + return + } + // Possible interrupt, fall through to the error path. + } + + // Error path. + err := tx.Rollback() + if err != nil { + panic(err) + } +} + +func (tx Tx) Commit() error { + return tx.c.Exec(`COMMIT`) +} + +func (tx Tx) Rollback() error { + // ROLLBACK even if the connection has been interrupted. + old := tx.c.SetInterrupt(context.Background()) + defer tx.c.SetInterrupt(old) + return tx.c.Exec(`ROLLBACK`) +} + +// Savepoint creates a named SQLite transaction using SAVEPOINT. +// +// On success Savepoint returns a release func that will call either +// RELEASE or ROLLBACK depending on whether the parameter *error +// points to a nil or non-nil error. +// +// This is meant to be deferred: +// +// func doWork(conn *sqlite3.Conn) (err error) { +// defer conn.Savepoint()(&err) +// +// // ... do work in the transaction +// } +// +// https://www.sqlite.org/lang_savepoint.html +func (c *Conn) Savepoint() (release func(*error)) { + name := "sqlite3.Savepoint" // names can be reused + var pc [1]uintptr + if n := runtime.Callers(2, pc[:]); n > 0 { + frames := runtime.CallersFrames(pc[:n]) + frame, _ := frames.Next() + if frame.Function != "" { + name = frame.Function + } + } + + err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) + if err != nil { + if errors.Is(err, INTERRUPT) { + return func(errp *error) { + if *errp == nil { + *errp = err + } + } + } + panic(err) + } + + return func(errp *error) { + recovered := recover() + if recovered != nil { + defer panic(recovered) + } + + if c.GetAutocommit() { + // There is nothing to commit/rollback. + return + } + + if *errp == nil && recovered == nil { + // Success path. + // RELEASE the savepoint successfully. + *errp = c.Exec(fmt.Sprintf("RELEASE %q;", name)) + if *errp == nil { + return + } + // Possible interrupt, fall through to the error path. + } + + // Error path. + // Always ROLLBACK even if the connection has been interrupted. + old := c.SetInterrupt(context.Background()) + defer c.SetInterrupt(old) + + err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name)) + if err != nil { + panic(err) + } + err = c.Exec(fmt.Sprintf("RELEASE %q;", name)) + if err != nil { + panic(err) + } + } +}