From 23aad5f62f0f338ff7ca994ec42993d7adf61adf Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 29 Sep 2025 12:45:18 +0100 Subject: [PATCH] MVCC API. --- tests/parallel/parallel_test.go | 12 ++-- vfs/mvcc/api.go | 109 ++++++++++++++++++++++++++------ vfs/mvcc/example_test.go | 2 +- vfs/mvcc/mvcc.go | 10 --- vfs/mvcc/mvcc_test.go | 5 +- vfs/tests/mptest/mptest_test.go | 6 +- 6 files changed, 100 insertions(+), 44 deletions(-) diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index 3e1630a..0eb672b 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -107,9 +107,9 @@ func Test_mvcc(t *testing.T) { iter = 5000 } - mvcc.Create("test.db", "") - name := "file:/test.db?vfs=mvcc" + - "&_pragma=busy_timeout(10000)" + name := mvcc.TestDB(t, mvcc.Snapshot{}, url.Values{ + "_pragma": {"busy_timeout(10000)"}, + }) createDB(t, name) testParallel(t, name, iter) testIntegrity(t, name) @@ -330,9 +330,9 @@ func Benchmark_memdb(b *testing.B) { } func Benchmark_mvcc(b *testing.B) { - mvcc.Create("test.db", "") - name := "file:/test.db?vfs=mvcc" + - "&_pragma=busy_timeout(10000)" + name := mvcc.TestDB(b, mvcc.Snapshot{}, url.Values{ + "_pragma": {"busy_timeout(10000)"}, + }) createDB(b, name) b.ResetTimer() diff --git a/vfs/mvcc/api.go b/vfs/mvcc/api.go index c131250..d8956fa 100644 --- a/vfs/mvcc/api.go +++ b/vfs/mvcc/api.go @@ -10,9 +10,15 @@ package mvcc import ( + "crypto/rand" + "fmt" + "net/url" + "strings" "sync" + "testing" "github.com/ncruces/go-sqlite3/vfs" + "github.com/ncruces/wbt" ) func init() { @@ -26,42 +32,103 @@ var ( ) // Create creates a shared memory database, -// using data as its initial contents. -func Create(name string, data string) { +// using a snapshot as its initial contents. +func Create(name string, snapshot Snapshot) { memoryMtx.Lock() defer memoryMtx.Unlock() - db := &mvccDB{ + memoryDBs[name] = &mvccDB{ refs: 1, name: name, - } - memoryDBs[name] = db - if len(data) == 0 { - return - } - // Convert data from WAL/2 to rollback journal. - if len(data) >= 20 && (false || - data[18] == 2 && data[19] == 2 || - data[18] == 3 && data[19] == 3) { - db.data = db.data. - Put(0, data[:18]). - Put(18, "\001\001"). - Put(20, data[20:]) - } else { - db.data = db.data.Put(0, data) + data: snapshot.Tree, } } // Delete deletes a shared memory database. func Delete(name string) { + name = getName(name) + memoryMtx.Lock() defer memoryMtx.Unlock() delete(memoryDBs, name) } -// Snapshot stores a snapshot of database src into dst. -func Snapshot(dst, src string) { +// Snapshot represents a database snapshot. +type Snapshot struct { + *wbt.Tree[int64, string] +} + +// NewSnapshot creates a snapshot from data. +func NewSnapshot(data string) Snapshot { + var tree *wbt.Tree[int64, string] + + // Convert data from WAL/2 to rollback journal. + if len(data) >= 20 && (false || + data[18] == 2 && data[19] == 2 || + data[18] == 3 && data[19] == 3) { + tree = tree. + Put(0, data[:18]). + Put(18, "\001\001"). + Put(20, data[20:]) + } else if len(data) > 0 { + tree = tree.Put(0, data) + } + + return Snapshot{tree} +} + +// TakeSnapshot takes a snapshot of a database. +// Name may be a URI filename. +func TakeSnapshot(name string) Snapshot { + name = getName(name) + memoryMtx.Lock() defer memoryMtx.Unlock() - memoryDBs[dst] = memoryDBs[src].fork() + db := memoryDBs[name] + if db == nil { + return Snapshot{} + } + + db.mtx.Lock() + defer db.mtx.Unlock() + return Snapshot{db.data} +} + +// TestDB creates a shared database from a snapshot for the test to use. +// The database is automatically deleted when the test and all its subtests complete. +// Returns a URI filename appropriate to call Open with. +// Each subsequent call to TestDB returns a unique database. +func TestDB(tb testing.TB, snapshot Snapshot, params ...url.Values) string { + tb.Helper() + + name := fmt.Sprintf("%s_%s", tb.Name(), rand.Text()) + tb.Cleanup(func() { Delete(name) }) + Create(name, snapshot) + + p := url.Values{"vfs": {"mvcc"}} + for _, v := range params { + for k, v := range v { + for _, v := range v { + p.Add(k, v) + } + } + } + + return (&url.URL{ + Scheme: "file", + OmitHost: true, + Path: "/" + name, + RawQuery: p.Encode(), + }).String() +} + +func getName(dsn string) string { + u, err := url.Parse(dsn) + if err == nil && + u.Scheme == "file" && + strings.HasPrefix(u.Path, "/") && + u.Query().Get("vfs") == "mvcc" { + return u.Path[1:] + } + return dsn } diff --git a/vfs/mvcc/example_test.go b/vfs/mvcc/example_test.go index d2c98a6..5598155 100644 --- a/vfs/mvcc/example_test.go +++ b/vfs/mvcc/example_test.go @@ -15,7 +15,7 @@ import ( var testDB string func Example() { - mvcc.Create("test.db", testDB) + mvcc.Create("test.db", mvcc.NewSnapshot(testDB)) db, err := sql.Open("sqlite3", "file:/test.db?vfs=mvcc") if err != nil { diff --git a/vfs/mvcc/mvcc.go b/vfs/mvcc/mvcc.go index b49e90e..47b4e7a 100644 --- a/vfs/mvcc/mvcc.go +++ b/vfs/mvcc/mvcc.go @@ -85,16 +85,6 @@ func (m *mvccDB) release() { } } -func (m *mvccDB) fork() *mvccDB { - m.mtx.Lock() - defer m.mtx.Unlock() - return &mvccDB{ - refs: 1, - name: m.name, - data: m.data, - } -} - type mvccFile struct { *mvccDB data *wbt.Tree[int64, string] diff --git a/vfs/mvcc/mvcc_test.go b/vfs/mvcc/mvcc_test.go index 8715adb..59a9552 100644 --- a/vfs/mvcc/mvcc_test.go +++ b/vfs/mvcc/mvcc_test.go @@ -14,10 +14,9 @@ var walDB string func Test_wal(t *testing.T) { t.Parallel() + dsn := TestDB(t, NewSnapshot(walDB)) - Create("test.db", walDB) - - db, err := sqlite3.Open("file:/test.db?vfs=mvcc") + db, err := sqlite3.Open(dsn) if err != nil { t.Fatal(err) } diff --git a/vfs/tests/mptest/mptest_test.go b/vfs/tests/mptest/mptest_test.go index 82f6719..107886c 100644 --- a/vfs/tests/mptest/mptest_test.go +++ b/vfs/tests/mptest/mptest_test.go @@ -197,7 +197,7 @@ func Test_multiwrite01_memory(t *testing.T) { } func Test_config01_mvcc(t *testing.T) { - mvcc.Create("test.db", "") + mvcc.Create("test.db", mvcc.Snapshot{}) ctx := util.NewContext(newContext(t)) cfg := config(ctx).WithArgs("mptest", "/test.db", "config01.test", "--vfs", "mvcc") @@ -213,7 +213,7 @@ func Test_crash01_mvcc(t *testing.T) { t.Skip("skipping in short mode") } - mvcc.Create("test.db", "") + mvcc.Create("test.db", mvcc.Snapshot{}) ctx := util.NewContext(newContext(t)) cfg := config(ctx).WithArgs("mptest", "/test.db", "crash01.test", "--vfs", "mvcc") @@ -229,7 +229,7 @@ func Test_multiwrite01_mvcc(t *testing.T) { t.Skip("skipping in slow CI") } - mvcc.Create("test.db", "") + mvcc.Create("test.db", mvcc.Snapshot{}) ctx := util.NewContext(newContext(t)) cfg := config(ctx).WithArgs("mptest", "/test.db", "multiwrite01.test", "--vfs", "mvcc")