diff --git a/config.go b/config.go index a21337b..ed3973b 100644 --- a/config.go +++ b/config.go @@ -64,3 +64,7 @@ func (c *Conn) Limit(id LimitCategory, value int) int { r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value)) return int(int32(r)) } + +func authorizerCallback(ctx context.Context, mod api.Module, pDB, action, zName3d, zName4th, zSchema, zInnerName uint32) uint32 { + return 0 +} diff --git a/conn.go b/conn.go index 6b1b867..a23dd60 100644 --- a/conn.go +++ b/conn.go @@ -326,9 +326,16 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 { return 0 } -// Pragma executes a PRAGMA statement and returns any results. -// -// https://sqlite.org/pragma.html +func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 { + return 0 +} + +func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {} + +func updateCallback(ctx context.Context, mod api.Module, pDB, action, zSchema, zTabName uint32, rowid uint64) { +} + +// Deprecated: executes a PRAGMA statement and returns results. func (c *Conn) Pragma(str string) ([]string, error) { stmt, _, err := c.Prepare(`PRAGMA ` + str) if err != nil { diff --git a/embed/exports.txt b/embed/exports.txt index 7a556f8..1ddb57c 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -28,6 +28,7 @@ sqlite3_changes64 sqlite3_clear_bindings sqlite3_close sqlite3_close_v2 +sqlite3_collation_needed_go sqlite3_column_blob sqlite3_column_bytes sqlite3_column_count @@ -39,6 +40,7 @@ sqlite3_column_text sqlite3_column_type sqlite3_column_value sqlite3_columns_go +sqlite3_commit_hook_go sqlite3_config_log_go sqlite3_create_aggregate_function_go sqlite3_create_collation_go @@ -78,6 +80,8 @@ sqlite3_result_pointer_go sqlite3_result_text64 sqlite3_result_value sqlite3_result_zeroblob64 +sqlite3_rollback_hook_go +sqlite3_set_authorizer_go sqlite3_set_auxdata_go sqlite3_set_last_insert_rowid sqlite3_step @@ -86,6 +90,7 @@ sqlite3_stmt_readonly sqlite3_stmt_status sqlite3_total_changes64 sqlite3_txn_state +sqlite3_update_hook_go sqlite3_uri_key sqlite3_uri_parameter sqlite3_value_blob diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index f502e2b..b0aed3c 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/func.go b/func.go index 9544ecc..7c69963 100644 --- a/func.go +++ b/func.go @@ -106,6 +106,8 @@ func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) { util.DelHandle(ctx, pApp) } +func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {} + func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) diff --git a/internal/util/func.go b/internal/util/func.go index 47dc890..be7a47c 100644 --- a/internal/util/func.go +++ b/internal/util/func.go @@ -75,6 +75,19 @@ func ExportFuncVIIIII[T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name Export(name) } +type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4) + +func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { + fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])) +} + +func ExportFuncVIIIIJ[T0, T1, T2, T3 i32, T4 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4)) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(funcVIIIIJ[T0, T1, T2, T3, T4](fn), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, nil). + Export(name) +} + type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) { @@ -140,6 +153,19 @@ func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, Export(name) } +type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR + +func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5]))) +} + +func ExportFuncIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(funcIIIIIII[TR, T0, T1, T2, T3, T4, T5](fn), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { diff --git a/sqlite.go b/sqlite.go index 5d2240a..9069ef2 100644 --- a/sqlite.go +++ b/sqlite.go @@ -290,6 +290,10 @@ func (a *arena) string(s string) uint32 { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncII(env, "go_progress", progressCallback) + util.ExportFuncII(env, "go_commit_hook", commitCallback) + util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) + util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback) + util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback) util.ExportFuncVIII(env, "go_log", logCallback) util.ExportFuncVI(env, "go_destroy", destroyCallback) util.ExportFuncVIIII(env, "go_func", funcCallback) @@ -297,9 +301,10 @@ func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncVIII(env, "go_final", finalCallback) util.ExportFuncVII(env, "go_value", valueCallback) util.ExportFuncVIIII(env, "go_inverse", inverseCallback) + util.ExportFuncVIIII(env, "go_collation_needed", collationCallback) util.ExportFuncIIIIII(env, "go_compare", compareCallback) - util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(0)) - util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(1)) + util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate)) + util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect)) util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback) util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback) util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback) diff --git a/sqlite3/func.c b/sqlite3/func.c index adff527..240ec3e 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -1,8 +1,11 @@ +#include #include #include "include.h" #include "sqlite3.h" +void go_collation_needed(void *, sqlite3 *, int, const char *); + int go_compare(go_handle, int, const void *, int, const void *); void go_func(sqlite3_context *, go_handle, int, sqlite3_value **); @@ -44,6 +47,11 @@ void go_inverse_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) { go_inverse(ctx, *agg, nArg, pArg); } +int sqlite3_collation_needed_go(sqlite3 *db, bool enable) { + return sqlite3_collation_needed(db, /*arg=*/NULL, + enable ? go_collation_needed : NULL); +} + int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) { int rc = sqlite3_create_collation_v2(db, name, SQLITE_UTF8, app, go_compare, go_destroy); diff --git a/sqlite3/hooks.c b/sqlite3/hooks.c new file mode 100644 index 0000000..b44f9a8 --- /dev/null +++ b/sqlite3/hooks.c @@ -0,0 +1,39 @@ +#include + +#include "sqlite3.h" + +int go_progress(void *); + +int go_commit_hook(void *); +void go_rollback_hook(void *); +void go_update_hook(void *, int, char const *, char const *, sqlite3_int64); + +int go_authorizer(void *, int, const char *, const char *, const char *, + const char *); + +void go_log(void *, int, const char *); + +void sqlite3_progress_handler_go(sqlite3 *db, int n) { + sqlite3_progress_handler(db, n, go_progress, /*arg=*/db); +} + +void sqlite3_commit_hook_go(sqlite3 *db, bool enable) { + sqlite3_commit_hook(db, enable ? go_commit_hook : NULL, /*arg=*/db); +} + +void sqlite3_rollback_hook_go(sqlite3 *db, bool enable) { + sqlite3_rollback_hook(db, enable ? go_rollback_hook : NULL, /*arg=*/db); +} + +void sqlite3_update_hook_go(sqlite3 *db, bool enable) { + sqlite3_update_hook(db, enable ? go_update_hook : NULL, /*arg=*/db); +} + +int sqlite3_set_authorizer_go(sqlite3 *db, bool enable) { + return sqlite3_set_authorizer(db, enable ? go_authorizer : NULL, /*arg=*/db); +} + +int sqlite3_config_log_go(bool enable) { + return sqlite3_config(SQLITE_CONFIG_LOG, enable ? go_log : NULL, + /*arg=*/NULL); +} \ No newline at end of file diff --git a/sqlite3/log.c b/sqlite3/log.c deleted file mode 100644 index bfcb66e..0000000 --- a/sqlite3/log.c +++ /dev/null @@ -1,9 +0,0 @@ -#include - -#include "sqlite3.h" - -void go_log(void *, int, const char *); - -int sqlite3_config_log_go(bool enable) { - return sqlite3_config(SQLITE_CONFIG_LOG, enable ? go_log : NULL, NULL); -} \ No newline at end of file diff --git a/sqlite3/main.c b/sqlite3/main.c index eb5074a..50a03b8 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -12,9 +12,8 @@ // Bindings #include "column.c" #include "func.c" -#include "log.c" +#include "hooks.c" #include "pointer.c" -#include "progress.c" #include "time.c" #include "vfs.c" #include "vtab.c" diff --git a/sqlite3/progress.c b/sqlite3/progress.c deleted file mode 100644 index 11c1551..0000000 --- a/sqlite3/progress.c +++ /dev/null @@ -1,9 +0,0 @@ -#include - -#include "sqlite3.h" - -int go_progress(void *); - -void sqlite3_progress_handler_go(sqlite3 *db, int n) { - sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL); -} \ No newline at end of file diff --git a/vtab.go b/vtab.go index 0455ead..784fce2 100644 --- a/vtab.go +++ b/vtab.go @@ -128,6 +128,13 @@ type VTabConstructor[T VTab] func(db *Conn, module, schema, table string, arg .. type module[T VTab] [2]VTabConstructor[T] +type vtabConstructor int + +const ( + xCreate vtabConstructor = 0 + xConnect vtabConstructor = 1 +) + // A VTab describes a particular instance of the virtual table. // A VTab may optionally implement [io.Closer] to free resources. // @@ -414,7 +421,7 @@ const ( INDEX_SCAN_UNIQUE IndexScanFlag = 1 ) -func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { +func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { return func(ctx context.Context, mod api.Module, pMod, nArg, pArg, ppVTab, pzErr uint32) uint32 { arg := make([]reflect.Value, 1+nArg) arg[0] = reflect.ValueOf(ctx.Value(connKey{})) @@ -425,7 +432,7 @@ func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _, } module := vtabGetHandle(ctx, mod, pMod) - res := reflect.ValueOf(module).Index(i).Call(arg) + res := reflect.ValueOf(module).Index(int(i)).Call(arg) err, _ := res[1].Interface().(error) if err == nil { vtabPutHandle(ctx, mod, ppVTab, res[0].Interface())