Towards callbacks.

This commit is contained in:
Nuno Cruces
2024-01-26 15:41:36 +00:00
parent 88cf845651
commit 019c71fb55
13 changed files with 111 additions and 27 deletions

View File

@@ -64,3 +64,7 @@ func (c *Conn) Limit(id LimitCategory, value int) int {
r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value)) r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value))
return int(int32(r)) return int(int32(r))
} }
func authorizerCallback(ctx context.Context, mod api.Module, pDB, action, zName3d, zName4th, zSchema, zInnerName uint32) uint32 {
return 0
}

13
conn.go
View File

@@ -326,9 +326,16 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 {
return 0 return 0
} }
// Pragma executes a PRAGMA statement and returns any results. func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 {
// return 0
// https://sqlite.org/pragma.html }
func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {}
func updateCallback(ctx context.Context, mod api.Module, pDB, action, zSchema, zTabName uint32, rowid uint64) {
}
// Deprecated: executes a PRAGMA statement and returns results.
func (c *Conn) Pragma(str string) ([]string, error) { func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str) stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil { if err != nil {

View File

@@ -28,6 +28,7 @@ sqlite3_changes64
sqlite3_clear_bindings sqlite3_clear_bindings
sqlite3_close sqlite3_close
sqlite3_close_v2 sqlite3_close_v2
sqlite3_collation_needed_go
sqlite3_column_blob sqlite3_column_blob
sqlite3_column_bytes sqlite3_column_bytes
sqlite3_column_count sqlite3_column_count
@@ -39,6 +40,7 @@ sqlite3_column_text
sqlite3_column_type sqlite3_column_type
sqlite3_column_value sqlite3_column_value
sqlite3_columns_go sqlite3_columns_go
sqlite3_commit_hook_go
sqlite3_config_log_go sqlite3_config_log_go
sqlite3_create_aggregate_function_go sqlite3_create_aggregate_function_go
sqlite3_create_collation_go sqlite3_create_collation_go
@@ -78,6 +80,8 @@ sqlite3_result_pointer_go
sqlite3_result_text64 sqlite3_result_text64
sqlite3_result_value sqlite3_result_value
sqlite3_result_zeroblob64 sqlite3_result_zeroblob64
sqlite3_rollback_hook_go
sqlite3_set_authorizer_go
sqlite3_set_auxdata_go sqlite3_set_auxdata_go
sqlite3_set_last_insert_rowid sqlite3_set_last_insert_rowid
sqlite3_step sqlite3_step
@@ -86,6 +90,7 @@ sqlite3_stmt_readonly
sqlite3_stmt_status sqlite3_stmt_status
sqlite3_total_changes64 sqlite3_total_changes64
sqlite3_txn_state sqlite3_txn_state
sqlite3_update_hook_go
sqlite3_uri_key sqlite3_uri_key
sqlite3_uri_parameter sqlite3_uri_parameter
sqlite3_value_blob sqlite3_value_blob

Binary file not shown.

View File

@@ -106,6 +106,8 @@ func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp) util.DelHandle(ctx, pApp)
} }
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {}
func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))

View File

@@ -75,6 +75,19 @@ func ExportFuncVIIIII[T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name
Export(name) Export(name)
} }
type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
}
func ExportFuncVIIIIJ[T0, T1, T2, T3 i32, T4 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIIIIJ[T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, nil).
Export(name)
}
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
@@ -140,6 +153,19 @@ func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder,
Export(name) Export(name)
} }
type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR
func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5])))
}
func ExportFuncIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIIII[TR, T0, T1, T2, T3, T4, T5](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {

View File

@@ -290,6 +290,10 @@ func (a *arena) string(s string) uint32 {
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", progressCallback) util.ExportFuncII(env, "go_progress", progressCallback)
util.ExportFuncII(env, "go_commit_hook", commitCallback)
util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
util.ExportFuncVIII(env, "go_log", logCallback) util.ExportFuncVIII(env, "go_log", logCallback)
util.ExportFuncVI(env, "go_destroy", destroyCallback) util.ExportFuncVI(env, "go_destroy", destroyCallback)
util.ExportFuncVIIII(env, "go_func", funcCallback) util.ExportFuncVIIII(env, "go_func", funcCallback)
@@ -297,9 +301,10 @@ func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVIII(env, "go_final", finalCallback) util.ExportFuncVIII(env, "go_final", finalCallback)
util.ExportFuncVII(env, "go_value", valueCallback) util.ExportFuncVII(env, "go_value", valueCallback)
util.ExportFuncVIIII(env, "go_inverse", inverseCallback) util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
util.ExportFuncIIIIII(env, "go_compare", compareCallback) util.ExportFuncIIIIII(env, "go_compare", compareCallback)
util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(0)) util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(1)) util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback) util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback) util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback) util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)

View File

@@ -1,8 +1,11 @@
#include <stdbool.h>
#include <stddef.h> #include <stddef.h>
#include "include.h" #include "include.h"
#include "sqlite3.h" #include "sqlite3.h"
void go_collation_needed(void *, sqlite3 *, int, const char *);
int go_compare(go_handle, int, const void *, int, const void *); int go_compare(go_handle, int, const void *, int, const void *);
void go_func(sqlite3_context *, go_handle, int, sqlite3_value **); void go_func(sqlite3_context *, go_handle, int, sqlite3_value **);
@@ -44,6 +47,11 @@ void go_inverse_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_inverse(ctx, *agg, nArg, pArg); go_inverse(ctx, *agg, nArg, pArg);
} }
int sqlite3_collation_needed_go(sqlite3 *db, bool enable) {
return sqlite3_collation_needed(db, /*arg=*/NULL,
enable ? go_collation_needed : NULL);
}
int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) { int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) {
int rc = sqlite3_create_collation_v2(db, name, SQLITE_UTF8, app, go_compare, int rc = sqlite3_create_collation_v2(db, name, SQLITE_UTF8, app, go_compare,
go_destroy); go_destroy);

39
sqlite3/hooks.c Normal file
View File

@@ -0,0 +1,39 @@
#include <stdbool.h>
#include "sqlite3.h"
int go_progress(void *);
int go_commit_hook(void *);
void go_rollback_hook(void *);
void go_update_hook(void *, int, char const *, char const *, sqlite3_int64);
int go_authorizer(void *, int, const char *, const char *, const char *,
const char *);
void go_log(void *, int, const char *);
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, /*arg=*/db);
}
void sqlite3_commit_hook_go(sqlite3 *db, bool enable) {
sqlite3_commit_hook(db, enable ? go_commit_hook : NULL, /*arg=*/db);
}
void sqlite3_rollback_hook_go(sqlite3 *db, bool enable) {
sqlite3_rollback_hook(db, enable ? go_rollback_hook : NULL, /*arg=*/db);
}
void sqlite3_update_hook_go(sqlite3 *db, bool enable) {
sqlite3_update_hook(db, enable ? go_update_hook : NULL, /*arg=*/db);
}
int sqlite3_set_authorizer_go(sqlite3 *db, bool enable) {
return sqlite3_set_authorizer(db, enable ? go_authorizer : NULL, /*arg=*/db);
}
int sqlite3_config_log_go(bool enable) {
return sqlite3_config(SQLITE_CONFIG_LOG, enable ? go_log : NULL,
/*arg=*/NULL);
}

View File

@@ -1,9 +0,0 @@
#include <stdbool.h>
#include "sqlite3.h"
void go_log(void *, int, const char *);
int sqlite3_config_log_go(bool enable) {
return sqlite3_config(SQLITE_CONFIG_LOG, enable ? go_log : NULL, NULL);
}

View File

@@ -12,9 +12,8 @@
// Bindings // Bindings
#include "column.c" #include "column.c"
#include "func.c" #include "func.c"
#include "log.c" #include "hooks.c"
#include "pointer.c" #include "pointer.c"
#include "progress.c"
#include "time.c" #include "time.c"
#include "vfs.c" #include "vfs.c"
#include "vtab.c" #include "vtab.c"

View File

@@ -1,9 +0,0 @@
#include <stddef.h>
#include "sqlite3.h"
int go_progress(void *);
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL);
}

11
vtab.go
View File

@@ -128,6 +128,13 @@ type VTabConstructor[T VTab] func(db *Conn, module, schema, table string, arg ..
type module[T VTab] [2]VTabConstructor[T] type module[T VTab] [2]VTabConstructor[T]
type vtabConstructor int
const (
xCreate vtabConstructor = 0
xConnect vtabConstructor = 1
)
// A VTab describes a particular instance of the virtual table. // A VTab describes a particular instance of the virtual table.
// A VTab may optionally implement [io.Closer] to free resources. // A VTab may optionally implement [io.Closer] to free resources.
// //
@@ -414,7 +421,7 @@ const (
INDEX_SCAN_UNIQUE IndexScanFlag = 1 INDEX_SCAN_UNIQUE IndexScanFlag = 1
) )
func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 {
return func(ctx context.Context, mod api.Module, pMod, nArg, pArg, ppVTab, pzErr uint32) uint32 { return func(ctx context.Context, mod api.Module, pMod, nArg, pArg, ppVTab, pzErr uint32) uint32 {
arg := make([]reflect.Value, 1+nArg) arg := make([]reflect.Value, 1+nArg)
arg[0] = reflect.ValueOf(ctx.Value(connKey{})) arg[0] = reflect.ValueOf(ctx.Value(connKey{}))
@@ -425,7 +432,7 @@ func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _,
} }
module := vtabGetHandle(ctx, mod, pMod) module := vtabGetHandle(ctx, mod, pMod)
res := reflect.ValueOf(module).Index(i).Call(arg) res := reflect.ValueOf(module).Index(int(i)).Call(arg)
err, _ := res[1].Interface().(error) err, _ := res[1].Interface().(error)
if err == nil { if err == nil {
vtabPutHandle(ctx, mod, ppVTab, res[0].Interface()) vtabPutHandle(ctx, mod, ppVTab, res[0].Interface())