mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Transactions.
This commit is contained in:
@@ -58,4 +58,5 @@ and WAL databases are not supported.
|
||||
|
||||
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
|
||||
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
|
||||
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
|
||||
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
|
||||
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
|
||||
76
conn.go
76
conn.go
@@ -2,10 +2,7 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
@@ -267,79 +264,6 @@ func (c *Conn) sendInterrupt() {
|
||||
c.call(c.api.interrupt, uint64(c.handle))
|
||||
}
|
||||
|
||||
// Savepoint creates a named SQLite transaction using SAVEPOINT.
|
||||
//
|
||||
// On success Savepoint returns a release func that will call
|
||||
// either RELEASE or ROLLBACK depending on whether the parameter *error
|
||||
// points to a nil or non-nil error.
|
||||
//
|
||||
// This is meant to be deferred:
|
||||
//
|
||||
// func doWork(conn *sqlite3.Conn) (err error) {
|
||||
// defer conn.Savepoint()(&err)
|
||||
//
|
||||
// // ... do work in the transaction
|
||||
// }
|
||||
func (conn *Conn) Savepoint() (release func(*error)) {
|
||||
name := "sqlite3.Savepoint" // names can be reused
|
||||
var pc [1]uintptr
|
||||
if n := runtime.Callers(2, pc[:]); n > 0 {
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, _ := frames.Next()
|
||||
if frame.Function != "" {
|
||||
name = frame.Function
|
||||
}
|
||||
}
|
||||
|
||||
err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
|
||||
if err != nil {
|
||||
if errors.Is(err, INTERRUPT) {
|
||||
return func(errp *error) {
|
||||
if *errp == nil {
|
||||
*errp = err
|
||||
}
|
||||
}
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return func(errp *error) {
|
||||
recovered := recover()
|
||||
if recovered != nil {
|
||||
defer panic(recovered)
|
||||
}
|
||||
|
||||
if conn.GetAutocommit() {
|
||||
// There is nothing to commit/rollback.
|
||||
return
|
||||
}
|
||||
|
||||
if *errp == nil && recovered == nil {
|
||||
// Success path.
|
||||
// RELEASE the savepoint successfully.
|
||||
*errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name))
|
||||
if *errp == nil {
|
||||
return
|
||||
}
|
||||
// Possible interrupt, fall through to the error path.
|
||||
}
|
||||
|
||||
// Error path.
|
||||
// Always ROLLBACK even if the connection has been interrupted.
|
||||
old := conn.SetInterrupt(context.Background())
|
||||
defer conn.SetInterrupt(old)
|
||||
|
||||
err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = conn.Exec(fmt.Sprintf("RELEASE %q;", name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pragma executes a PRAGMA statement and returns any result as a string.
|
||||
//
|
||||
// https://www.sqlite.org/pragma.html
|
||||
|
||||
@@ -8,6 +8,244 @@ import (
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestConn_Transaction_exec(t *testing.T) {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
errFailed := errors.New("failed")
|
||||
|
||||
count := func() int {
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnInt(0)
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return 0
|
||||
}
|
||||
|
||||
insert := func(succeed bool) (err error) {
|
||||
tx := db.Begin()
|
||||
defer tx.End(&err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if succeed {
|
||||
return nil
|
||||
}
|
||||
return errFailed
|
||||
}
|
||||
|
||||
err = insert(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := count(); got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
err = insert(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := count(); got != 2 {
|
||||
t.Errorf("got %d, want 2", got)
|
||||
}
|
||||
|
||||
err = insert(false)
|
||||
if err != errFailed {
|
||||
t.Errorf("got %v, want errFailed", err)
|
||||
}
|
||||
if got := count(); got != 2 {
|
||||
t.Errorf("got %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_panic(t *testing.T) {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('one');`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
panics := func() (err error) {
|
||||
tx := db.Begin()
|
||||
defer tx.End(&err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
panic("omg!")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
p := recover()
|
||||
if p != "omg!" {
|
||||
t.Errorf("got %v, want panic", p)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
}()
|
||||
|
||||
err = panics()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_interrupt(t *testing.T) {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx, err := db.BeginImmediate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`INSERT INTO test(col) VALUES(1)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx.End(&err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db.SetInterrupt(ctx)
|
||||
|
||||
tx, err = db.BeginExclusive()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`INSERT INTO test(col) VALUES(2)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
_, err = db.BeginImmediate()
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test(col) VALUES(3)`)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
var nilErr error
|
||||
tx.End(&nilErr)
|
||||
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
|
||||
}
|
||||
|
||||
db.SetInterrupt(context.Background())
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
}
|
||||
err = stmt.Err()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_rollback(t *testing.T) {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
err = db.Exec(`INSERT INTO test(col) VALUES(1)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx.End(&err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
}
|
||||
err = stmt.Err()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Savepoint_exec(t *testing.T) {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
@@ -183,7 +421,7 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
|
||||
var nilErr error
|
||||
release1(&nilErr)
|
||||
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
|
||||
}
|
||||
|
||||
db.SetInterrupt(context.Background())
|
||||
171
tx.go
Normal file
171
tx.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
type Tx struct {
|
||||
c *Conn
|
||||
}
|
||||
|
||||
// Begin starts a deferred transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (c *Conn) Begin() Tx {
|
||||
err := c.Exec(`BEGIN DEFERRED`)
|
||||
if err != nil && !errors.Is(err, INTERRUPT) {
|
||||
panic(err)
|
||||
}
|
||||
return Tx{c}
|
||||
}
|
||||
|
||||
// BeginImmediate starts an immediate transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (c *Conn) BeginImmediate() (Tx, error) {
|
||||
err := c.Exec(`BEGIN IMMEDIATE`)
|
||||
if err != nil {
|
||||
return Tx{}, err
|
||||
}
|
||||
return Tx{c}, nil
|
||||
}
|
||||
|
||||
// BeginExclusive starts an exclusive transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (c *Conn) BeginExclusive() (Tx, error) {
|
||||
err := c.Exec(`BEGIN EXCLUSIVE`)
|
||||
if err != nil {
|
||||
return Tx{}, err
|
||||
}
|
||||
return Tx{c}, nil
|
||||
}
|
||||
|
||||
// End calls either [Commit] or [Rollback]
|
||||
// depending on whether *error points to a nil or non-nil error.
|
||||
//
|
||||
// This is meant to be deferred:
|
||||
//
|
||||
// func doWork(conn *sqlite3.Conn) (err error) {
|
||||
// tx := conn.Begin()
|
||||
// defer tx.End(&err)
|
||||
//
|
||||
// // ... do work in the transaction
|
||||
// }
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func (tx Tx) End(errp *error) {
|
||||
recovered := recover()
|
||||
if recovered != nil {
|
||||
defer panic(recovered)
|
||||
}
|
||||
|
||||
if tx.c.GetAutocommit() {
|
||||
// There is nothing to commit/rollback.
|
||||
return
|
||||
}
|
||||
|
||||
if *errp == nil && recovered == nil {
|
||||
// Success path.
|
||||
*errp = tx.Commit()
|
||||
if *errp == nil {
|
||||
return
|
||||
}
|
||||
// Possible interrupt, fall through to the error path.
|
||||
}
|
||||
|
||||
// Error path.
|
||||
err := tx.Rollback()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (tx Tx) Commit() error {
|
||||
return tx.c.Exec(`COMMIT`)
|
||||
}
|
||||
|
||||
func (tx Tx) Rollback() error {
|
||||
// ROLLBACK even if the connection has been interrupted.
|
||||
old := tx.c.SetInterrupt(context.Background())
|
||||
defer tx.c.SetInterrupt(old)
|
||||
return tx.c.Exec(`ROLLBACK`)
|
||||
}
|
||||
|
||||
// Savepoint creates a named SQLite transaction using SAVEPOINT.
|
||||
//
|
||||
// On success Savepoint returns a release func that will call either
|
||||
// RELEASE or ROLLBACK depending on whether the parameter *error
|
||||
// points to a nil or non-nil error.
|
||||
//
|
||||
// This is meant to be deferred:
|
||||
//
|
||||
// func doWork(conn *sqlite3.Conn) (err error) {
|
||||
// defer conn.Savepoint()(&err)
|
||||
//
|
||||
// // ... do work in the transaction
|
||||
// }
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func (c *Conn) Savepoint() (release func(*error)) {
|
||||
name := "sqlite3.Savepoint" // names can be reused
|
||||
var pc [1]uintptr
|
||||
if n := runtime.Callers(2, pc[:]); n > 0 {
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, _ := frames.Next()
|
||||
if frame.Function != "" {
|
||||
name = frame.Function
|
||||
}
|
||||
}
|
||||
|
||||
err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
|
||||
if err != nil {
|
||||
if errors.Is(err, INTERRUPT) {
|
||||
return func(errp *error) {
|
||||
if *errp == nil {
|
||||
*errp = err
|
||||
}
|
||||
}
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return func(errp *error) {
|
||||
recovered := recover()
|
||||
if recovered != nil {
|
||||
defer panic(recovered)
|
||||
}
|
||||
|
||||
if c.GetAutocommit() {
|
||||
// There is nothing to commit/rollback.
|
||||
return
|
||||
}
|
||||
|
||||
if *errp == nil && recovered == nil {
|
||||
// Success path.
|
||||
// RELEASE the savepoint successfully.
|
||||
*errp = c.Exec(fmt.Sprintf("RELEASE %q;", name))
|
||||
if *errp == nil {
|
||||
return
|
||||
}
|
||||
// Possible interrupt, fall through to the error path.
|
||||
}
|
||||
|
||||
// Error path.
|
||||
// Always ROLLBACK even if the connection has been interrupted.
|
||||
old := c.SetInterrupt(context.Background())
|
||||
defer c.SetInterrupt(old)
|
||||
|
||||
err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = c.Exec(fmt.Sprintf("RELEASE %q;", name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user