diff --git a/config.go b/config.go index 8a84bc9..6876ba5 100644 --- a/config.go +++ b/config.go @@ -284,7 +284,10 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa // // https://sqlite.org/c3ref/autovacuum_pages.html func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error { - funcPtr := util.AddHandle(c.ctx, cb) + var funcPtr uint32 + if cb != nil { + funcPtr = util.AddHandle(c.ctx, cb) + } r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr)) return c.error(r) } diff --git a/embed/bcw2/bcw2.wasm b/embed/bcw2/bcw2.wasm index 5805df4..eeaab47 100755 Binary files a/embed/bcw2/bcw2.wasm and b/embed/bcw2/bcw2.wasm differ diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 47aa1ae..6b388a3 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/func.go b/func.go index 4eac249..7ff740d 100644 --- a/func.go +++ b/func.go @@ -44,9 +44,12 @@ func (c Conn) AnyCollationNeeded() error { // // https://sqlite.org/c3ref/create_collation.html func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { + var funcPtr uint32 defer c.arena.mark()() namePtr := c.arena.string(name) - funcPtr := util.AddHandle(c.ctx, fn) + if fn != nil { + funcPtr = util.AddHandle(c.ctx, fn) + } r := c.call("sqlite3_create_collation_go", uint64(c.handle), uint64(namePtr), uint64(funcPtr)) return c.error(r) @@ -56,9 +59,12 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { + var funcPtr uint32 defer c.arena.mark()() namePtr := c.arena.string(name) - funcPtr := util.AddHandle(c.ctx, fn) + if fn != nil { + funcPtr = util.AddHandle(c.ctx, fn) + } r := c.call("sqlite3_create_function_go", uint64(c.handle), uint64(namePtr), uint64(nArg), uint64(flag), uint64(funcPtr)) @@ -75,10 +81,13 @@ type ScalarFunction func(ctx Context, arg ...Value) // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { + var funcPtr uint32 defer c.arena.mark()() - call := "sqlite3_create_aggregate_function_go" namePtr := c.arena.string(name) - funcPtr := util.AddHandle(c.ctx, fn) + if fn != nil { + funcPtr = util.AddHandle(c.ctx, fn) + } + call := "sqlite3_create_aggregate_function_go" if _, ok := fn().(WindowFunction); ok { call = "sqlite3_create_window_function_go" } @@ -188,11 +197,12 @@ func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) // We need to create the aggregate. fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() - handle := util.AddHandle(db.ctx, fn) if pAgg != 0 { + handle := util.AddHandle(db.ctx, fn) util.WriteUint32(db.mod, pAgg, handle) + return fn, handle } - return fn, handle + return fn, 0 } func callbackArgs(db *Conn, arg []Value, pArg uint32) { diff --git a/internal/util/handle.go b/internal/util/handle.go index 4584324..e4e3385 100644 --- a/internal/util/handle.go +++ b/internal/util/handle.go @@ -35,17 +35,22 @@ func DelHandle(ctx context.Context, id uint32) error { s := ctx.Value(moduleKey{}).(*moduleState) a := s.handles[^id] s.handles[^id] = nil - s.holes++ + if l := uint32(len(s.handles)); l == ^id { + s.handles = s.handles[:l-1] + } else { + s.holes++ + } if c, ok := a.(io.Closer); ok { return c.Close() } return nil } -func AddHandle(ctx context.Context, a any) (id uint32) { +func AddHandle(ctx context.Context, a any) uint32 { if a == nil { panic(NilErr) } + s := ctx.Value(moduleKey{}).(*moduleState) // Find an empty slot. diff --git a/sqlite3/column.c b/sqlite3/column.c index c876f07..a319d27 100644 --- a/sqlite3/column.c +++ b/sqlite3/column.c @@ -22,6 +22,7 @@ int sqlite3_columns_go(sqlite3_stmt *stmt, int nCol, char *aType, switch (aType[i] = sqlite3_column_type(stmt, i)) { default: // SQLITE_NULL aData[i] = (union sqlite3_data){}; + continue; case SQLITE_INTEGER: aData[i].i = sqlite3_column_int64(stmt, i); continue; diff --git a/sqlite3/func.c b/sqlite3/func.c index 240ec3e..9bb830c 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -53,6 +53,9 @@ int sqlite3_collation_needed_go(sqlite3 *db, bool enable) { } int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) { + if (app == NULL) { + return sqlite3_create_collation_v2(db, name, SQLITE_UTF8, NULL, NULL, NULL); + } int rc = sqlite3_create_collation_v2(db, name, SQLITE_UTF8, app, go_compare, go_destroy); if (rc) go_destroy(app); @@ -61,6 +64,10 @@ int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) { int sqlite3_create_function_go(sqlite3 *db, const char *name, int argc, int flags, go_handle app) { + if (app == NULL) { + return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, NULL, + NULL, NULL, NULL, NULL); + } return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, app, go_func_wrapper, /*step=*/NULL, /*final=*/NULL, go_destroy); @@ -68,6 +75,10 @@ int sqlite3_create_function_go(sqlite3 *db, const char *name, int argc, int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *name, int argc, int flags, go_handle app) { + if (app == NULL) { + return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, NULL, + NULL, NULL, NULL, NULL); + } return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, app, /*func=*/NULL, go_step_wrapper, go_final_wrapper, go_destroy); @@ -75,6 +86,10 @@ int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *name, int sqlite3_create_window_function_go(sqlite3 *db, const char *name, int argc, int flags, go_handle app) { + if (app == NULL) { + return sqlite3_create_window_function(db, name, argc, SQLITE_UTF8 | flags, + NULL, NULL, NULL, NULL, NULL, NULL); + } return sqlite3_create_window_function( db, name, argc, SQLITE_UTF8 | flags, app, go_step_wrapper, go_final_wrapper, go_value_wrapper, go_inverse_wrapper, go_destroy); diff --git a/sqlite3/hooks.c b/sqlite3/hooks.c index 4acc2fd..1bdb405 100644 --- a/sqlite3/hooks.c +++ b/sqlite3/hooks.c @@ -57,9 +57,10 @@ int sqlite3_config_log_go(bool enable) { } int sqlite3_autovacuum_pages_go(sqlite3 *db, go_handle app) { - int rc = sqlite3_autovacuum_pages(db, go_autovacuum_pages, app, go_destroy); - if (rc) go_destroy(app); - return rc; + if (app == NULL) { + return sqlite3_autovacuum_pages(db, NULL, NULL, NULL); + } + return sqlite3_autovacuum_pages(db, go_autovacuum_pages, app, go_destroy); } #ifndef sqliteBusyCallback diff --git a/sqlite3/vtab.c b/sqlite3/vtab.c index 9f7421a..29735f2 100644 --- a/sqlite3/vtab.c +++ b/sqlite3/vtab.c @@ -163,6 +163,10 @@ static int go_vtab_shadown_name_wrapper(const char *zName) { return true; } int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags, go_handle handle) { + if (handle == NULL) { + return sqlite3_create_module_v2(db, zName, NULL, NULL, NULL); + } + struct go_module *mod = malloc(sizeof(struct go_module)); if (mod == NULL) { go_destroy(handle); diff --git a/tests/func_test.go b/tests/func_test.go index acedc86..0a74d9e 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -209,6 +209,24 @@ func TestCreateFunction_error(t *testing.T) { stmt.Step() } +func TestCreateFunction_delete(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.CreateFunction("regexp", 2, 0, nil) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`SELECT 'a' REGEXP 'a|b'`) + if err == nil { + t.Error("want error") + } +} + func TestOverloadFunction(t *testing.T) { t.Parallel() diff --git a/tests/vtab_test.go b/tests/vtab_test.go new file mode 100644 index 0000000..9266f73 --- /dev/null +++ b/tests/vtab_test.go @@ -0,0 +1,20 @@ +package tests + +import ( + "testing" + + "github.com/ncruces/go-sqlite3" +) + +func TestCreateModule_delete(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = sqlite3.CreateModule[sqlite3.VTab](db, "generate_series", nil, nil) + if err != nil { + t.Fatal(err) + } +} diff --git a/vtab.go b/vtab.go index 80b3a33..2bb294b 100644 --- a/vtab.go +++ b/vtab.go @@ -57,9 +57,12 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor flags |= VTAB_SHADOWTABS } + var modulePtr uint32 defer db.arena.mark()() namePtr := db.arena.string(name) - modulePtr := util.AddHandle(db.ctx, module[T]{create, connect}) + if connect != nil { + modulePtr = util.AddHandle(db.ctx, module[T]{create, connect}) + } r := db.call("sqlite3_create_module_go", uint64(db.handle), uint64(namePtr), uint64(flags), uint64(modulePtr)) return db.error(r)