Update, authorizer callbacks.

This commit is contained in:
Nuno Cruces
2024-01-27 10:57:46 +00:00
parent c9cc893ed7
commit 031087327d
8 changed files with 147 additions and 14 deletions

View File

@@ -65,6 +65,39 @@ func (c *Conn) Limit(id LimitCategory, value int) int {
return int(int32(r))
}
func authorizerCallback(ctx context.Context, mod api.Module, pDB, action, zName3rd, zName4th, zSchema, zInnerName uint32) uint32 {
return 0
// SetAuthorizer registers an authorizer callback with the database connection.
//
// https://sqlite.org/c3ref/set_authorizer.html
func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, nameInner string) AuthorizerReturnCode) error {
var enable uint64
if cb != nil {
enable = 1
}
r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
return err
}
c.authorizer = cb
return nil
}
func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zNameInner uint32) AuthorizerReturnCode {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil {
var name3rd, name4th, schema, nameInner string
if zName3rd != 0 {
name3rd = util.ReadString(mod, zName3rd, _MAX_NAME)
}
if zName4th != 0 {
name4th = util.ReadString(mod, zName4th, _MAX_NAME)
}
if zSchema != 0 {
schema = util.ReadString(mod, zSchema, _MAX_NAME)
}
if zNameInner != 0 {
nameInner = util.ReadString(mod, zNameInner, _MAX_NAME)
}
return c.authorizer(action, name3rd, name4th, schema, nameInner)
}
return AUTH_OK
}

16
conn.go
View File

@@ -18,13 +18,15 @@ import (
type Conn struct {
*sqlite
interrupt context.Context
pending *Stmt
log func(code xErrorCode, msg string)
collation func(name string)
commit func() bool
rollback func()
arena arena
interrupt context.Context
pending *Stmt
log func(xErrorCode, string)
collation func(*Conn, string)
authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode
update func(AuthorizerActionCode, string, string, int64)
commit func() bool
rollback func()
arena arena
handle uint32
}

View File

@@ -249,6 +249,62 @@ const (
LIMIT_WORKER_THREADS LimitCategory = 11
)
// AuthorizerActionCode are the integer action codes
// that the authorizer callback may be passed.
//
// https://sqlite.org/c3ref/c_alter_table.html
type AuthorizerActionCode uint32
const (
/************************************************ 3rd ************ 4th ***********/
CREATE_INDEX AuthorizerActionCode = 1 /* Index Name Table Name */
CREATE_TABLE AuthorizerActionCode = 2 /* Table Name NULL */
CREATE_TEMP_INDEX AuthorizerActionCode = 3 /* Index Name Table Name */
CREATE_TEMP_TABLE AuthorizerActionCode = 4 /* Table Name NULL */
CREATE_TEMP_TRIGGER AuthorizerActionCode = 5 /* Trigger Name Table Name */
CREATE_TEMP_VIEW AuthorizerActionCode = 6 /* View Name NULL */
CREATE_TRIGGER AuthorizerActionCode = 7 /* Trigger Name Table Name */
CREATE_VIEW AuthorizerActionCode = 8 /* View Name NULL */
DELETE AuthorizerActionCode = 9 /* Table Name NULL */
DROP_INDEX AuthorizerActionCode = 10 /* Index Name Table Name */
DROP_TABLE AuthorizerActionCode = 11 /* Table Name NULL */
DROP_TEMP_INDEX AuthorizerActionCode = 12 /* Index Name Table Name */
DROP_TEMP_TABLE AuthorizerActionCode = 13 /* Table Name NULL */
DROP_TEMP_TRIGGER AuthorizerActionCode = 14 /* Trigger Name Table Name */
DROP_TEMP_VIEW AuthorizerActionCode = 15 /* View Name NULL */
DROP_TRIGGER AuthorizerActionCode = 16 /* Trigger Name Table Name */
DROP_VIEW AuthorizerActionCode = 17 /* View Name NULL */
INSERT AuthorizerActionCode = 18 /* Table Name NULL */
PRAGMA AuthorizerActionCode = 19 /* Pragma Name 1st arg or NULL */
READ AuthorizerActionCode = 20 /* Table Name Column Name */
SELECT AuthorizerActionCode = 21 /* NULL NULL */
TRANSACTION AuthorizerActionCode = 22 /* Operation NULL */
UPDATE AuthorizerActionCode = 23 /* Table Name Column Name */
ATTACH AuthorizerActionCode = 24 /* Filename NULL */
DETACH AuthorizerActionCode = 25 /* Database Name NULL */
ALTER_TABLE AuthorizerActionCode = 26 /* Database Name Table Name */
REINDEX AuthorizerActionCode = 27 /* Index Name NULL */
ANALYZE AuthorizerActionCode = 28 /* Table Name NULL */
CREATE_VTABLE AuthorizerActionCode = 29 /* Table Name Module Name */
DROP_VTABLE AuthorizerActionCode = 30 /* Table Name Module Name */
FUNCTION AuthorizerActionCode = 31 /* NULL Function Name */
SAVEPOINT AuthorizerActionCode = 32 /* Operation Savepoint Name */
COPY AuthorizerActionCode = 0 /* No longer used */
RECURSIVE AuthorizerActionCode = 33 /* NULL NULL */
)
// AuthorizerReturnCode are the integer codes
// that the authorizer callback may return.
//
// https://sqlite.org/c3ref/c_deny.html
type AuthorizerReturnCode uint32
const (
AUTH_OK AuthorizerReturnCode = 0
AUTH_DENY AuthorizerReturnCode = 1 /* Abort the SQL statement with an error */
AUTH_IGNORE AuthorizerReturnCode = 2 /* Don't allow access, but don't generate an error */
)
// TxnState are the allowed return values from [Conn.TxnState].
//
// https://sqlite.org/c3ref/c_txn_none.html

View File

@@ -12,7 +12,7 @@ import (
// whenever an unknown collation sequence is required.
//
// https://sqlite.org/c3ref/collation_needed.html
func (c *Conn) CollationNeeded(cb func(name string)) error {
func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
var enable uint64
if cb != nil {
enable = 1
@@ -126,7 +126,7 @@ func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil {
name := util.ReadString(mod, zName, _MAX_NAME)
c.collation(name)
c.collation(c, name)
}
}

View File

@@ -28,7 +28,7 @@ func ExampleConn_CreateCollation() {
log.Fatal(err)
}
err = db.CollationNeeded(func(name string) {
err = db.CollationNeeded(func(db *sqlite3.Conn, name string) {
err := unicode.RegisterCollation(db, name, name)
if err != nil {
log.Fatal(err)

View File

@@ -396,6 +396,28 @@ func TestConn_Limit(t *testing.T) {
}
}
func TestConn_SetAuthorizer(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.SetAuthorizer(func(action sqlite3.AuthorizerActionCode, name3rd, name4th, schema, nameInner string) sqlite3.AuthorizerReturnCode {
return sqlite3.AUTH_DENY
})
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT * FROM sqlite_schema`)
if !errors.Is(err, sqlite3.AUTH) {
t.Errorf("got %v, want sqlite3.AUTH", err)
}
}
func TestConn_ReleaseMemory(t *testing.T) {
t.Parallel()

View File

@@ -18,8 +18,9 @@ func TestConn_Transaction_exec(t *testing.T) {
}
defer db.Close()
db.CommitHook(func() (ok bool) { return true })
db.RollbackHook(func() {})
db.CommitHook(func() bool { return true })
db.UpdateHook(func(sqlite3.AuthorizerActionCode, string, string, int64) {})
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {

21
txn.go
View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
@@ -259,6 +260,19 @@ func (c *Conn) RollbackHook(cb func()) {
c.rollback = cb
}
// RollbackHook registers a callback function to be invoked
// whenever a row is updated, inserted or deleted in a rowid table.
//
// https://sqlite.org/c3ref/update_hook.html
func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
var enable uint64
if cb != nil {
enable = 1
}
c.call("sqlite3_update_hook_go", uint64(c.handle), enable)
c.update = cb
}
func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
if !c.commit() {
@@ -274,5 +288,10 @@ func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {
}
}
func updateCallback(ctx context.Context, mod api.Module, pDB, action, zSchema, zTabName uint32, rowid uint64) {
func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil {
schema := util.ReadString(mod, zSchema, _MAX_NAME)
table := util.ReadString(mod, zTabName, _MAX_NAME)
c.update(action, schema, table, int64(rowid))
}
}