From c9cc893ed72d058a84e9ba45136e294b68c73d9d Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 27 Jan 2024 10:05:31 +0000 Subject: [PATCH] Commit callback. --- config.go | 2 +- conn.go | 11 ++--------- tests/txn_test.go | 3 +++ txn.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index ed3973b..38f25e1 100644 --- a/config.go +++ b/config.go @@ -65,6 +65,6 @@ func (c *Conn) Limit(id LimitCategory, value int) int { return int(int32(r)) } -func authorizerCallback(ctx context.Context, mod api.Module, pDB, action, zName3d, zName4th, zSchema, zInnerName uint32) uint32 { +func authorizerCallback(ctx context.Context, mod api.Module, pDB, action, zName3rd, zName4th, zSchema, zInnerName uint32) uint32 { return 0 } diff --git a/conn.go b/conn.go index 410e1d1..96390e3 100644 --- a/conn.go +++ b/conn.go @@ -22,6 +22,8 @@ type Conn struct { pending *Stmt log func(code xErrorCode, msg string) collation func(name string) + commit func() bool + rollback func() arena arena handle uint32 @@ -327,15 +329,6 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 { return 0 } -func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 { - return 0 -} - -func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {} - -func updateCallback(ctx context.Context, mod api.Module, pDB, action, zSchema, zTabName uint32, rowid uint64) { -} - // Deprecated: executes a PRAGMA statement and returns results. func (c *Conn) Pragma(str string) ([]string, error) { stmt, _, err := c.Prepare(`PRAGMA ` + str) diff --git a/tests/txn_test.go b/tests/txn_test.go index 61fed91..db9793d 100644 --- a/tests/txn_test.go +++ b/tests/txn_test.go @@ -18,6 +18,9 @@ func TestConn_Transaction_exec(t *testing.T) { } defer db.Close() + db.CommitHook(func() (ok bool) { return true }) + db.RollbackHook(func() {}) + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) if err != nil { t.Fatal(err) diff --git a/txn.go b/txn.go index be4bb86..f778220 100644 --- a/txn.go +++ b/txn.go @@ -8,6 +8,8 @@ import ( "runtime" "strconv" "strings" + + "github.com/tetratelabs/wazero/api" ) // Txn is an in-progress database transaction. @@ -229,3 +231,48 @@ func (c *Conn) TxnState(schema string) TxnState { // Deprecated: renamed for consistency with [Conn.TxnState]. type Tx = Txn + +// CommitHook registers a callback function to be invoked +// whenever a transaction is committed. +// Return true to allow the commit operation to continue normally. +// +// https://sqlite.org/c3ref/commit_hook.html +func (c *Conn) CommitHook(cb func() (ok bool)) { + var enable uint64 + if cb != nil { + enable = 1 + } + c.call("sqlite3_commit_hook_go", uint64(c.handle), enable) + c.commit = cb +} + +// RollbackHook registers a callback function to be invoked +// whenever a transaction is rolled back. +// +// https://sqlite.org/c3ref/commit_hook.html +func (c *Conn) RollbackHook(cb func()) { + var enable uint64 + if cb != nil { + enable = 1 + } + c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable) + c.rollback = cb +} + +func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { + if !c.commit() { + return 1 + } + } + return 0 +} + +func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil { + c.rollback() + } +} + +func updateCallback(ctx context.Context, mod api.Module, pDB, action, zSchema, zTabName uint32, rowid uint64) { +}