From 01464960361c55a409df6cb6a6b47329a3060a5f Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 24 Feb 2023 14:31:41 +0000 Subject: [PATCH] Nested transactions. --- README.md | 48 +++++++- api.go | 2 + conn.go | 59 ++++++++-- conn_test.go | 3 +- driver/example_test.go | 5 +- embed/build.sh | 1 + embed/sqlite3.wasm | Bin 676281 -> 676392 bytes save.go | 76 ++++++++++++ stmt.go | 1 + tests/conn_test.go | 42 ++++--- tests/save_test.go | 260 +++++++++++++++++++++++++++++++++++++++++ 11 files changed, 458 insertions(+), 39 deletions(-) create mode 100644 save.go create mode 100644 tests/save_test.go diff --git a/README.md b/README.md index 22ff781..7258e7b 100644 --- a/README.md +++ b/README.md @@ -4,22 +4,58 @@ [![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3) [![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html) -⚠️ CAUTION ⚠️ +### ⚠️ Work in Progress ⚠️ -This is a WIP. +Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/), +and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings. + +- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3) +wraps the [C SQLite API](https://www.sqlite.org/cintro.html) +([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)). +- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver) +provides a [`database/sql`](https://pkg.go.dev/database/sql) driver +([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)). +- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed) +embeds a build of SQLite into your application. + +### Caveats + +Because WASM does not support shared memory, +[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm). + +To work around this limitation, SQLite is compiled with +[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode), +making `EXCLUSIVE` the default locking mode. +For non-WAL databases, `NORMAL` locking mode can be activated with +[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode). + +Because connection pooling is incompatible with `EXCLUSIVE` locking mode, +the `database/sql` driver defaults to `NORMAL` locking mode, +and WAL databases are not supported. + +### Roadmap -Roadmap: - [x] build SQLite using `zig cc --target=wasm32-wasi` - [x] `:memory:` databases - [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go - branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly -- [x] design a simple, nice API, enough for simple use cases +- [x] design a nice API, enough for simple use cases - [x] provide a simple `database/sql` driver - [x] file locking, compatible with SQLite on macOS/Linux/Windows - [ ] advanced SQLite features - - [ ] nested transactions + - [x] nested transactions - [ ] incremental BLOB I/O - [ ] online backup - [ ] session extension - [ ] snapshots - - [ ] SQL functions \ No newline at end of file + - [ ] SQL functions +- [ ] custom VFSes + - [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt) + - [ ] in-memory VFS, wrapping a [`bytes.Buffer`](https://pkg.go.dev/bytes#Buffer) + - [ ] expose a custom VFS API + +### Alternatives + +- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite) +- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite) +- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3) \ No newline at end of file diff --git a/api.go b/api.go index 698cda1..0b0e598 100644 --- a/api.go +++ b/api.go @@ -62,6 +62,7 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) { columnText: getFun("sqlite3_column_text"), columnBlob: getFun("sqlite3_column_blob"), columnBytes: getFun("sqlite3_column_bytes"), + autocommit: getFun("sqlite3_get_autocommit"), lastRowid: getFun("sqlite3_last_insert_rowid"), changes: getFun("sqlite3_changes64"), interrupt: getFun("sqlite3_interrupt"), @@ -106,6 +107,7 @@ type sqliteAPI struct { columnText api.Function columnBlob api.Function columnBytes api.Function + autocommit api.Function lastRowid api.Function changes api.Function interrupt api.Function diff --git a/conn.go b/conn.go index b90ed51..034b978 100644 --- a/conn.go +++ b/conn.go @@ -91,6 +91,17 @@ func (c *Conn) Close() error { return c.mem.mod.Close(c.ctx) } +// GetAutocommit tests the connection for auto-commit mode. +// +// https://www.sqlite.org/c3ref/get_autocommit.html +func (c *Conn) GetAutocommit() bool { + r, err := c.api.autocommit.Call(c.ctx, uint64(c.handle)) + if err != nil { + panic(err) + } + return r[0] != 0 +} + // SetInterrupt interrupts a long-running query when done is closed. // // Subsequent uses of the connection will return [INTERRUPT] @@ -111,22 +122,31 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) { c.waiter = nil } - // Finalize the uncompleted SQL statement. - if c.pending != nil { - c.pending.Close() - c.pending = nil - } - old = c.done c.done = done if done == nil { + // Finalize the uncompleted SQL statement. + if c.pending != nil { + c.pending.Close() + c.pending = nil + } return old } // Creating an uncompleted SQL statement prevents SQLite from ignoring // an interrupt that comes before any other statements are started. - c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`) - c.pending.Step() + if c.pending == nil { + c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`) + c.pending.Step() + } else { + c.pending.Reset() + } + + // Don't create the goroutine if we're already interrupted. + // This happens frequently while restoring to a previously interrupted state. + if c.checkInterrupt() { + return old + } waiter := make(chan struct{}) c.waiter = waiter @@ -154,11 +174,25 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) { return old } +func (c *Conn) checkInterrupt() bool { + select { + case <-c.done: // Done was closed. + _, err := c.api.interrupt.Call(c.ctx, uint64(c.handle)) + if err != nil { + panic(err) + } + return true + default: + return false + } +} + // Exec is a convenience function that allows an application to run // multiple statements of SQL without having to use a lot of code. // // https://www.sqlite.org/c3ref/exec.html func (c *Conn) Exec(sql string) error { + c.checkInterrupt() defer c.arena.reset() sqlPtr := c.arena.string(sql) @@ -326,6 +360,15 @@ type arena struct { ptrs []uint32 } +func (a *arena) free() { + if a.c == nil { + return + } + a.reset() + a.c.free(a.base) + a.c = nil +} + func (a *arena) reset() { for _, ptr := range a.ptrs { a.c.free(ptr) diff --git a/conn_test.go b/conn_test.go index 94e1195..d00bfca 100644 --- a/conn_test.go +++ b/conn_test.go @@ -30,7 +30,7 @@ func TestConn_newArena(t *testing.T) { defer db.Close() arena := db.newArena(16) - defer arena.reset() + defer arena.free() const title = "Lorem ipsum" @@ -50,6 +50,7 @@ func TestConn_newArena(t *testing.T) { if got := db.mem.readString(ptr, math.MaxUint32); got != body { t.Errorf("got %q, want %q", got, body) } + arena.free() } func TestConn_newBytes(t *testing.T) { diff --git a/driver/example_test.go b/driver/example_test.go index f5df3f4..9c7842b 100644 --- a/driver/example_test.go +++ b/driver/example_test.go @@ -31,7 +31,8 @@ func Example() { defer db.Close() defer os.Remove("./recordings.db") - err = createAlbumsTable() + // Create a table with some data in it. + err = albumsSetup() if err != nil { log.Fatal(err) } @@ -65,7 +66,7 @@ func Example() { // ID of added album: 5 } -func createAlbumsTable() error { +func albumsSetup() error { _, err := db.Exec(` DROP TABLE IF EXISTS album; CREATE TABLE album ( diff --git a/embed/build.sh b/embed/build.sh index b6448c9..d80d08d 100755 --- a/embed/build.sh +++ b/embed/build.sh @@ -45,6 +45,7 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \ -Wl,--export=sqlite3_column_text \ -Wl,--export=sqlite3_column_blob \ -Wl,--export=sqlite3_column_bytes \ + -Wl,--export=sqlite3_get_autocommit \ -Wl,--export=sqlite3_last_insert_rowid \ -Wl,--export=sqlite3_changes64 \ -Wl,--export=sqlite3_interrupt \ diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index cb9693d1fc1c8c717a875906520b34bccce03ce3..d83ac83d36c82ce93270d23caaad994a1f5c030f 100755 GIT binary patch delta 1062 zcmZ`&OH30{6n(F)z`Iz>uLVp*F=C8~XwbMZ>9kGM4GZf~p4enMQ)Z@!tW;10 z|BiV1851?exbPDfjR^}~lE{v|F($0sxuN3k7E>x17w2W}dH3COX5N{%k6WHSX?b(%hUCMM+iDj5zxyJZDTP*|OMZ^=S!Pw#O`6 zrr38;g`FX&I>LU4vi#1Ogsy6dLCJm!Z=N%yC2jUg1oA{&ld2i(PUt-ez1LzF0N)<~ ze=WNRm8&g=mQL!iZq4Xsmz+*f?H&R@b}Ha^v0b*lgy_h0^^l&&StIV?(pvky$S%xCm~7hp>TAL?|YdWM;i$ zXLWHgt7on3WM;8RJmN)yslwemaI z5h3Guu`5Dpb8s*MjqZD}9o1uW;B5qad>rQ@(7^BGmk31>u{2861hz%V{v@7=LOq|t Zbd-)y<2Y5dGtivR$J%?&JdM(O{RW!5WkCP{ delta 1007 zcmZWnNl#Nz6uy^NhI_D#Z5b+pOkp%ej4=`yCT>VNVdA=J%LSXfzE@sf39^wvWF8-I z7~BwDn7~3qVl{E0KY*wS3l{85*f}g^)JvftxcFXjzjMy_zB9afS^fS+_3L|*Gz^1l z_oZ!b#mNCPCpe8qxK0k44SzVmCXqi3e>Q^EPP!pCk#nNw zmsLOM0So$^V7X3u!FnN~MUDDpPWt|<)On|Fv@=Wc<1i=v#xk?zKRZkE2bC7@k+^U> zV0>n_xzflLFAH$mYOVx0L3Jn|`R zJR%yIfzuHOWHrLc>{*s*M<_?bI+=shr3psWylQLf(LN9||M3t|TR53NXDgy>WFelE z5kCPn%oT6P4YIiQJpFc(Y5h(4VBe;BVAD(mk_03HlR&aSia@GBT5QwIYV%_06>OQt z-d3`LkLGt9)TH=AoDOc#rP$UstLIoME~8+FS&uL9T0jkjd}3w1GpuU*E;c*BO82nM z0Y!8l`yEh04{*f+SLq?HJK#!^xT)_BsFb>)41YC>1FCUCChfsHPAI3n=yi&UKJ0fw zC4Gvkg2WXkXbYB2?&tEOfiXFNcu3po|XVkPAxb2(F4S9mTIM zsG?)I>w 0 { + frames := runtime.CallersFrames(pc[:n]) + frame, _ := frames.Next() + if frame.Function != "" { + name = frame.Function + } + } + + err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) + if err != nil { + return func(errp *error) { + if *errp == nil { + *errp = err + } + } + } + + return func(errp *error) { + recovered := recover() + if recovered != nil { + defer panic(recovered) + } + + if conn.GetAutocommit() { + // There is nothing to commit/rollback. + return + } + + if *errp == nil && recovered == nil { + // Success path. + // RELEASE the savepoint successfully. + *errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) + if *errp == nil { + return + } + // Possible interrupt, fall through to the error path. + } + + // Error path. + // Always ROLLBACK even if the connection has been interrupted. + old := conn.SetInterrupt(nil) + defer conn.SetInterrupt(old) + + err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name)) + if err != nil { + panic(err) + } + err = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) + if err != nil { + panic(err) + } + } +} diff --git a/stmt.go b/stmt.go index 59997fe..919ea2a 100644 --- a/stmt.go +++ b/stmt.go @@ -66,6 +66,7 @@ func (s *Stmt) ClearBindings() error { // // https://www.sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { + s.c.checkInterrupt() r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle)) if err != nil { panic(err) diff --git a/tests/conn_test.go b/tests/conn_test.go index fa80daf..32d9fa5 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -103,40 +103,38 @@ func TestConn_SetInterrupt(t *testing.T) { } defer stmt.Close() - cancel() db.SetInterrupt(ctx.Done()) + cancel() var serr *sqlite3.Error // Interrupting works. err = stmt.Exec() - if err != nil { - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) - } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) } // Interrupting sticks. err = db.Exec(`SELECT 1`) - if err != nil { - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) - } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) } - db.SetInterrupt(nil) + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + db.SetInterrupt(ctx.Done()) // Interrupting can be cleared. err = db.Exec(`SELECT 1`) diff --git a/tests/save_test.go b/tests/save_test.go new file mode 100644 index 0000000..2be1513 --- /dev/null +++ b/tests/save_test.go @@ -0,0 +1,260 @@ +package tests + +import ( + "context" + "errors" + "testing" + + "github.com/ncruces/go-sqlite3" +) + +func TestConn_Savepoint_exec(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + errFailed := errors.New("failed") + + count := func() int { + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + return stmt.ColumnInt(0) + } + t.Fatal(stmt.Err()) + return 0 + } + + insert := func(succeed bool) (err error) { + defer db.Savepoint()(&err) + + err = db.Exec(`INSERT INTO test VALUES ('hello')`) + if err != nil { + t.Fatal(err) + } + + if succeed { + return nil + } + return errFailed + } + + err = insert(true) + if err != nil { + t.Fatal(err) + } + if got := count(); got != 1 { + t.Errorf("got %d, want 1", got) + } + + err = insert(true) + if err != nil { + t.Fatal(err) + } + if got := count(); got != 2 { + t.Errorf("got %d, want 2", got) + } + + err = insert(false) + if err != errFailed { + t.Errorf("got %v, want errFailed", err) + } + if got := count(); got != 2 { + t.Errorf("got %d, want 2", got) + } +} + +func TestConn_Savepoint_panic(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO test VALUES ('one');`) + if err != nil { + t.Fatal(err) + } + + panics := func() (err error) { + defer db.Savepoint()(&err) + + err = db.Exec(`INSERT INTO test VALUES ('hello')`) + if err != nil { + return err + } + + panic("omg!") + } + + defer func() { + p := recover() + if p != "omg!" { + t.Errorf("got %v, want panic", p) + } + + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + got := stmt.ColumnInt(0) + if got != 1 { + t.Errorf("got %d, want 1", got) + } + return + } + t.Fatal(stmt.Err()) + }() + + err = panics() + if err != nil { + t.Error(err) + } +} + +func TestConn_Savepoint_interrupt(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + release := db.Savepoint() + err = db.Exec(`INSERT INTO test(col) VALUES(1)`) + if err != nil { + t.Fatal(err) + } + release(&err) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + db.SetInterrupt(ctx.Done()) + + release1 := db.Savepoint() + err = db.Exec(`INSERT INTO test(col) VALUES(2)`) + if err != nil { + t.Fatal(err) + } + release2 := db.Savepoint() + err = db.Exec(`INSERT INTO test(col) VALUES(3)`) + if err != nil { + t.Fatal(err) + } + + checkInterrupt := func(err error) { + var serr *sqlite3.Error + if err == nil { + t.Fatal("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) + } + } + + cancel() + db.Savepoint()(&err) + checkInterrupt(err) + + err = db.Exec(`INSERT INTO test(col) VALUES(4)`) + checkInterrupt(err) + + err = context.Canceled + release2(&err) + if err != context.Canceled { + t.Fatal(err) + } + + var nilErr error + release1(&nilErr) + checkInterrupt(nilErr) + + db.SetInterrupt(nil) + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + got := stmt.ColumnInt(0) + if got != 1 { + t.Errorf("got %d, want 1", got) + } + } + err = stmt.Err() + if err != nil { + t.Error(err) + } +} + +func TestConn_Savepoint_rollback(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + release := db.Savepoint() + err = db.Exec(`INSERT INTO test(col) VALUES(1)`) + if err != nil { + t.Fatal(err) + } + err = db.Exec(`COMMIT`) + if err != nil { + t.Fatal(err) + } + release(&err) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + got := stmt.ColumnInt(0) + if got != 1 { + t.Errorf("got %d, want 1", got) + } + } + err = stmt.Err() + if err != nil { + t.Error(err) + } +}