From 031087327dd48f3409a0234aecc48419524c0ef9 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 27 Jan 2024 10:57:46 +0000 Subject: [PATCH] Update, authorizer callbacks. --- config.go | 37 ++++++++++++++++++++++++++++-- conn.go | 16 +++++++------ const.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++ func.go | 4 ++-- func_test.go | 2 +- tests/conn_test.go | 22 ++++++++++++++++++ tests/txn_test.go | 3 ++- txn.go | 21 ++++++++++++++++- 8 files changed, 147 insertions(+), 14 deletions(-) diff --git a/config.go b/config.go index 38f25e1..2f82d08 100644 --- a/config.go +++ b/config.go @@ -65,6 +65,39 @@ func (c *Conn) Limit(id LimitCategory, value int) int { return int(int32(r)) } -func authorizerCallback(ctx context.Context, mod api.Module, pDB, action, zName3rd, zName4th, zSchema, zInnerName uint32) uint32 { - return 0 +// SetAuthorizer registers an authorizer callback with the database connection. +// +// https://sqlite.org/c3ref/set_authorizer.html +func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, nameInner string) AuthorizerReturnCode) error { + var enable uint64 + if cb != nil { + enable = 1 + } + r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable) + if err := c.error(r); err != nil { + return err + } + c.authorizer = cb + return nil + +} + +func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zNameInner uint32) AuthorizerReturnCode { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil { + var name3rd, name4th, schema, nameInner string + if zName3rd != 0 { + name3rd = util.ReadString(mod, zName3rd, _MAX_NAME) + } + if zName4th != 0 { + name4th = util.ReadString(mod, zName4th, _MAX_NAME) + } + if zSchema != 0 { + schema = util.ReadString(mod, zSchema, _MAX_NAME) + } + if zNameInner != 0 { + nameInner = util.ReadString(mod, zNameInner, _MAX_NAME) + } + return c.authorizer(action, name3rd, name4th, schema, nameInner) + } + return AUTH_OK } diff --git a/conn.go b/conn.go index 96390e3..3636336 100644 --- a/conn.go +++ b/conn.go @@ -18,13 +18,15 @@ import ( type Conn struct { *sqlite - interrupt context.Context - pending *Stmt - log func(code xErrorCode, msg string) - collation func(name string) - commit func() bool - rollback func() - arena arena + interrupt context.Context + pending *Stmt + log func(xErrorCode, string) + collation func(*Conn, string) + authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode + update func(AuthorizerActionCode, string, string, int64) + commit func() bool + rollback func() + arena arena handle uint32 } diff --git a/const.go b/const.go index 49ac2d1..63e5c5a 100644 --- a/const.go +++ b/const.go @@ -249,6 +249,62 @@ const ( LIMIT_WORKER_THREADS LimitCategory = 11 ) +// AuthorizerActionCode are the integer action codes +// that the authorizer callback may be passed. +// +// https://sqlite.org/c3ref/c_alter_table.html +type AuthorizerActionCode uint32 + +const ( + /************************************************ 3rd ************ 4th ***********/ + CREATE_INDEX AuthorizerActionCode = 1 /* Index Name Table Name */ + CREATE_TABLE AuthorizerActionCode = 2 /* Table Name NULL */ + CREATE_TEMP_INDEX AuthorizerActionCode = 3 /* Index Name Table Name */ + CREATE_TEMP_TABLE AuthorizerActionCode = 4 /* Table Name NULL */ + CREATE_TEMP_TRIGGER AuthorizerActionCode = 5 /* Trigger Name Table Name */ + CREATE_TEMP_VIEW AuthorizerActionCode = 6 /* View Name NULL */ + CREATE_TRIGGER AuthorizerActionCode = 7 /* Trigger Name Table Name */ + CREATE_VIEW AuthorizerActionCode = 8 /* View Name NULL */ + DELETE AuthorizerActionCode = 9 /* Table Name NULL */ + DROP_INDEX AuthorizerActionCode = 10 /* Index Name Table Name */ + DROP_TABLE AuthorizerActionCode = 11 /* Table Name NULL */ + DROP_TEMP_INDEX AuthorizerActionCode = 12 /* Index Name Table Name */ + DROP_TEMP_TABLE AuthorizerActionCode = 13 /* Table Name NULL */ + DROP_TEMP_TRIGGER AuthorizerActionCode = 14 /* Trigger Name Table Name */ + DROP_TEMP_VIEW AuthorizerActionCode = 15 /* View Name NULL */ + DROP_TRIGGER AuthorizerActionCode = 16 /* Trigger Name Table Name */ + DROP_VIEW AuthorizerActionCode = 17 /* View Name NULL */ + INSERT AuthorizerActionCode = 18 /* Table Name NULL */ + PRAGMA AuthorizerActionCode = 19 /* Pragma Name 1st arg or NULL */ + READ AuthorizerActionCode = 20 /* Table Name Column Name */ + SELECT AuthorizerActionCode = 21 /* NULL NULL */ + TRANSACTION AuthorizerActionCode = 22 /* Operation NULL */ + UPDATE AuthorizerActionCode = 23 /* Table Name Column Name */ + ATTACH AuthorizerActionCode = 24 /* Filename NULL */ + DETACH AuthorizerActionCode = 25 /* Database Name NULL */ + ALTER_TABLE AuthorizerActionCode = 26 /* Database Name Table Name */ + REINDEX AuthorizerActionCode = 27 /* Index Name NULL */ + ANALYZE AuthorizerActionCode = 28 /* Table Name NULL */ + CREATE_VTABLE AuthorizerActionCode = 29 /* Table Name Module Name */ + DROP_VTABLE AuthorizerActionCode = 30 /* Table Name Module Name */ + FUNCTION AuthorizerActionCode = 31 /* NULL Function Name */ + SAVEPOINT AuthorizerActionCode = 32 /* Operation Savepoint Name */ + COPY AuthorizerActionCode = 0 /* No longer used */ + RECURSIVE AuthorizerActionCode = 33 /* NULL NULL */ +) + +// AuthorizerReturnCode are the integer codes +// that the authorizer callback may return. +// +// https://sqlite.org/c3ref/c_deny.html +type AuthorizerReturnCode uint32 + +const ( + AUTH_OK AuthorizerReturnCode = 0 + AUTH_DENY AuthorizerReturnCode = 1 /* Abort the SQL statement with an error */ + AUTH_IGNORE AuthorizerReturnCode = 2 /* Don't allow access, but don't generate an error */ +) + // TxnState are the allowed return values from [Conn.TxnState]. // // https://sqlite.org/c3ref/c_txn_none.html diff --git a/func.go b/func.go index 85650d4..255584a 100644 --- a/func.go +++ b/func.go @@ -12,7 +12,7 @@ import ( // whenever an unknown collation sequence is required. // // https://sqlite.org/c3ref/collation_needed.html -func (c *Conn) CollationNeeded(cb func(name string)) error { +func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { var enable uint64 if cb != nil { enable = 1 @@ -126,7 +126,7 @@ func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) { func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil { name := util.ReadString(mod, zName, _MAX_NAME) - c.collation(name) + c.collation(c, name) } } diff --git a/func_test.go b/func_test.go index baf4050..12a41a5 100644 --- a/func_test.go +++ b/func_test.go @@ -28,7 +28,7 @@ func ExampleConn_CreateCollation() { log.Fatal(err) } - err = db.CollationNeeded(func(name string) { + err = db.CollationNeeded(func(db *sqlite3.Conn, name string) { err := unicode.RegisterCollation(db, name, name) if err != nil { log.Fatal(err) diff --git a/tests/conn_test.go b/tests/conn_test.go index 72583c6..f2937fa 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -396,6 +396,28 @@ func TestConn_Limit(t *testing.T) { } } +func TestConn_SetAuthorizer(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.SetAuthorizer(func(action sqlite3.AuthorizerActionCode, name3rd, name4th, schema, nameInner string) sqlite3.AuthorizerReturnCode { + return sqlite3.AUTH_DENY + }) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`SELECT * FROM sqlite_schema`) + if !errors.Is(err, sqlite3.AUTH) { + t.Errorf("got %v, want sqlite3.AUTH", err) + } +} + func TestConn_ReleaseMemory(t *testing.T) { t.Parallel() diff --git a/tests/txn_test.go b/tests/txn_test.go index db9793d..3699ffe 100644 --- a/tests/txn_test.go +++ b/tests/txn_test.go @@ -18,8 +18,9 @@ func TestConn_Transaction_exec(t *testing.T) { } defer db.Close() - db.CommitHook(func() (ok bool) { return true }) db.RollbackHook(func() {}) + db.CommitHook(func() bool { return true }) + db.UpdateHook(func(sqlite3.AuthorizerActionCode, string, string, int64) {}) err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) if err != nil { diff --git a/txn.go b/txn.go index f778220..223ffa1 100644 --- a/txn.go +++ b/txn.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/ncruces/go-sqlite3/internal/util" "github.com/tetratelabs/wazero/api" ) @@ -259,6 +260,19 @@ func (c *Conn) RollbackHook(cb func()) { c.rollback = cb } +// RollbackHook registers a callback function to be invoked +// whenever a row is updated, inserted or deleted in a rowid table. +// +// https://sqlite.org/c3ref/update_hook.html +func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) { + var enable uint64 + if cb != nil { + enable = 1 + } + c.call("sqlite3_update_hook_go", uint64(c.handle), enable) + c.update = cb +} + func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { if !c.commit() { @@ -274,5 +288,10 @@ func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) { } } -func updateCallback(ctx context.Context, mod api.Module, pDB, action, zSchema, zTabName uint32, rowid uint64) { +func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil { + schema := util.ReadString(mod, zSchema, _MAX_NAME) + table := util.ReadString(mod, zTabName, _MAX_NAME) + c.update(action, schema, table, int64(rowid)) + } }