diff --git a/conn.go b/conn.go index 56df742..fe8295c 100644 --- a/conn.go +++ b/conn.go @@ -265,7 +265,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { return old } -func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 { +func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 { if c, ok := ctx.Value(connKey{}).(*Conn); ok { if c.interrupt != nil && c.interrupt.Err() != nil { return 1 diff --git a/context.go b/context.go index 601c37e..6de20bb 100644 --- a/context.go +++ b/context.go @@ -210,17 +210,8 @@ func (ctx Context) ResultError(err error) { uint64(ctx.handle), uint64(ptr), uint64(len(str))) ctx.c.free(ptr) - var code uint64 - var ecode ErrorCode - var xcode xErrorCode - switch { - case errors.As(err, &xcode): - code = uint64(xcode) - case errors.As(err, &ecode): - code = uint64(ecode) - } - if code != 0 { + if code := errorCode(err, _OK); code != _OK { ctx.c.call(ctx.c.api.resultErrorCode, - uint64(ctx.handle), code) + uint64(ctx.handle), uint64(code)) } } diff --git a/driver/time_test.go b/driver/time_test.go index e0306ff..03dfb6a 100644 --- a/driver/time_test.go +++ b/driver/time_test.go @@ -58,7 +58,7 @@ func Fuzz_stringOrTime_2(f *testing.F) { f.Add(639095955742, 222_222_222) // twosday, year 22222AD f.Add(-763421161058, 222_222_222) // twosday, year 22222BC - checkTime := func(t *testing.T, date time.Time) { + checkTime := func(t testing.TB, date time.Time) { value := stringOrTime([]byte(date.Format(time.RFC3339Nano))) switch v := value.(type) { diff --git a/embed/exports.txt b/embed/exports.txt index 51a4d1c..1233f1b 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -77,4 +77,15 @@ sqlite3_result_value sqlite3_result_error sqlite3_result_error_code sqlite3_result_error_nomem -sqlite3_result_error_toobig \ No newline at end of file +sqlite3_result_error_toobig +sqlite3_create_module_go +sqlite3_declare_vtab +sqlite3_vtab_config_go +sqlite3_vtab_collation +sqlite3_vtab_distinct +sqlite3_vtab_in +sqlite3_vtab_in_first +sqlite3_vtab_in_next +sqlite3_vtab_rhs_value +sqlite3_vtab_nochange +sqlite3_vtab_on_conflict \ No newline at end of file diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 3e3ea0e..446660c 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/error.go b/error.go index 7d9cf6f..899c67c 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "errors" "strconv" "strings" @@ -135,3 +136,18 @@ func (e ExtendedErrorCode) Temporary() bool { func (e ExtendedErrorCode) Timeout() bool { return e == BUSY_TIMEOUT } + +func errorCode(err error, def ErrorCode) (code uint32) { + var ecode ErrorCode + var xcode xErrorCode + switch { + case errors.As(err, &xcode): + return uint32(xcode) + case errors.As(err, &ecode): + return uint32(ecode) + } + if err != nil { + return uint32(def) + } + return _OK +} diff --git a/func.go b/func.go index e7ecc60..ce50a69 100644 --- a/func.go +++ b/func.go @@ -21,6 +21,7 @@ func (c *Conn) AnyCollationNeeded() { // // https://sqlite.org/c3ref/create_collation.html func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { + defer c.arena.reset() namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) r := c.call(c.api.createCollation, @@ -32,6 +33,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error { + defer c.arena.reset() namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) r := c.call(c.api.createFunction, @@ -46,6 +48,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func( // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { + defer c.arena.reset() call := c.api.createAggregate namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) @@ -81,55 +84,55 @@ type WindowFunction interface { Inverse(ctx Context, arg ...Value) } -func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) { +func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) { util.DelHandle(ctx, pApp) } -func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { +func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) } -func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { +func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { db := ctx.Value(connKey{}).(*Conn) - fn := callbackHandle(db, pCtx).(func(ctx Context, arg ...Value)) + fn := userDataHandle(db, pCtx).(func(ctx Context, arg ...Value)) fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) } -func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { +func stepCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { db := ctx.Value(connKey{}).(*Conn) - fn := callbackAggregate(db, pCtx, nil).(AggregateFunction) + fn := aggregateCtxHandle(db, pCtx, nil).(AggregateFunction) fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) } -func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) { +func finalCallback(ctx context.Context, mod api.Module, pCtx uint32) { var handle uint32 db := ctx.Value(connKey{}).(*Conn) - fn := callbackAggregate(db, pCtx, &handle).(AggregateFunction) + fn := aggregateCtxHandle(db, pCtx, &handle).(AggregateFunction) fn.Value(Context{db, pCtx}) if err := util.DelHandle(ctx, handle); err != nil { Context{db, pCtx}.ResultError(err) } } -func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) { +func valueCallback(ctx context.Context, mod api.Module, pCtx uint32) { db := ctx.Value(connKey{}).(*Conn) - fn := callbackAggregate(db, pCtx, nil).(AggregateFunction) + fn := aggregateCtxHandle(db, pCtx, nil).(AggregateFunction) fn.Value(Context{db, pCtx}) } -func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { +func inverseCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { db := ctx.Value(connKey{}).(*Conn) - fn := callbackAggregate(db, pCtx, nil).(WindowFunction) + fn := aggregateCtxHandle(db, pCtx, nil).(WindowFunction) fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) } -func callbackHandle(db *Conn, pCtx uint32) any { +func userDataHandle(db *Conn, pCtx uint32) any { pApp := uint32(db.call(db.api.userData, uint64(pCtx))) return util.GetHandle(db.ctx, pApp) } -func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any { +func aggregateCtxHandle(db *Conn, pCtx uint32, close *uint32) any { // On close, we're getting rid of the handle. // Don't allocate space to store it. var size uint64 @@ -152,7 +155,7 @@ func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any { } // Create a new aggregate and store the handle. - fn := callbackHandle(db, pCtx).(func() AggregateFunction)() + fn := userDataHandle(db, pCtx).(func() AggregateFunction)() if ptr != 0 { util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn)) } diff --git a/sqlite.go b/sqlite.go index 6b025c8..6de79a0 100644 --- a/sqlite.go +++ b/sqlite.go @@ -183,6 +183,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) { resultErrorCode: getFun("sqlite3_result_error_code"), resultErrorMem: getFun("sqlite3_result_error_nomem"), resultErrorBig: getFun("sqlite3_result_error_toobig"), + createModule: getFun("sqlite3_create_module_go"), + declareVTab: getFun("sqlite3_declare_vtab"), } if err != nil { return nil, err @@ -407,17 +409,42 @@ type sqliteAPI struct { resultErrorCode api.Function resultErrorMem api.Function resultErrorBig api.Function + createModule api.Function + declareVTab api.Function destructor uint32 } func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { - util.ExportFuncII(env, "go_progress", callbackProgress) - util.ExportFuncVI(env, "go_destroy", callbackDestroy) - util.ExportFuncVIII(env, "go_func", callbackFunc) - util.ExportFuncVIII(env, "go_step", callbackStep) - util.ExportFuncVI(env, "go_final", callbackFinal) - util.ExportFuncVI(env, "go_value", callbackValue) - util.ExportFuncVIII(env, "go_inverse", callbackInverse) - util.ExportFuncIIIIII(env, "go_compare", callbackCompare) + util.ExportFuncII(env, "go_progress", progressCallback) + util.ExportFuncVI(env, "go_destroy", destroyCallback) + util.ExportFuncVIII(env, "go_func", funcCallback) + util.ExportFuncVIII(env, "go_step", stepCallback) + util.ExportFuncVI(env, "go_final", finalCallback) + util.ExportFuncVI(env, "go_value", valueCallback) + util.ExportFuncVIII(env, "go_inverse", inverseCallback) + util.ExportFuncIIIIII(env, "go_compare", compareCallback) + util.ExportFuncIIIIII(env, "go_vtab_create", vtabConnectCallback) + util.ExportFuncIIIIII(env, "go_vtab_connect", vtabConnectCallback) + util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback) + util.ExportFuncII(env, "go_vtab_destroy", vtabDisconnectCallback) + util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback) + util.ExportFuncIIIII(env, "go_vtab_update", vtabCallbackIIII) + util.ExportFuncIII(env, "go_vtab_rename", vtabCallbackII) + util.ExportFuncIIIII(env, "go_vtab_find_function", vtabCallbackIIII) + util.ExportFuncII(env, "go_vtab_begin", vtabCallbackI) + util.ExportFuncII(env, "go_vtab_sync", vtabCallbackI) + util.ExportFuncII(env, "go_vtab_commit", vtabCallbackI) + util.ExportFuncII(env, "go_vtab_rollback", vtabCallbackI) + util.ExportFuncIII(env, "go_vtab_savepoint", vtabCallbackII) + util.ExportFuncIII(env, "go_vtab_release", vtabCallbackII) + util.ExportFuncIII(env, "go_vtab_rollback_to", vtabCallbackII) + util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback) + util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback) + util.ExportFuncII(env, "go_cur_close", cursorCallbackI) + util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback) + util.ExportFuncII(env, "go_cur_next", cursorCallbackI) + util.ExportFuncII(env, "go_cur_eof", cursorCallbackI) + util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback) + util.ExportFuncIII(env, "go_cur_rowid", cursorRowidCallback) return env } diff --git a/sqlite3/func.c b/sqlite3/func.c index c10bbed..ce35292 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -1,12 +1,7 @@ #include #include "sqlite3.h" - -typedef void *go_handle; - -void go_destroy(go_handle); - -static_assert(sizeof(go_handle) == 4, "Unexpected size"); +#include "types.h" void go_func(sqlite3_context *, int, sqlite3_value **); void go_step(sqlite3_context *, int, sqlite3_value **); diff --git a/sqlite3/main.c b/sqlite3/main.c index 6189976..bc63b35 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -8,12 +8,15 @@ #include "ext/series.c" #include "ext/uint.c" #include "ext/uuid.c" +// Bindings #include "func.c" #include "pointer.c" #include "progress.c" #include "time.c" #include "vfs.c" -// #include "vtab.c" +#include "vtab.c" + +sqlite3_destructor_type malloc_destructor = &free; __attribute__((constructor)) void init() { sqlite3_initialize(); diff --git a/sqlite3/pointer.c b/sqlite3/pointer.c index d9a317e..c4a66c8 100644 --- a/sqlite3/pointer.c +++ b/sqlite3/pointer.c @@ -1,5 +1,6 @@ #include "sqlite3.h" +#include "types.h" #define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer" diff --git a/sqlite3/types.h b/sqlite3/types.h new file mode 100644 index 0000000..d5e369f --- /dev/null +++ b/sqlite3/types.h @@ -0,0 +1,7 @@ +#pragma once + +typedef void *go_handle; + +void go_destroy(go_handle); + +static_assert(sizeof(go_handle) == 4, "Unexpected size"); \ No newline at end of file diff --git a/sqlite3/vfs.c b/sqlite3/vfs.c index 896f86f..dc2b2ce 100644 --- a/sqlite3/vfs.c +++ b/sqlite3/vfs.c @@ -3,6 +3,7 @@ #include #include "sqlite3.h" +#include "types.h" int go_localtime(struct tm *, sqlite3_int64); int go_vfs_find(const char *zVfsName); @@ -83,8 +84,6 @@ int sqlite3_os_init() { return sqlite3_vfs_register(&os_vfs, /*default=*/true); } -sqlite3_destructor_type malloc_destructor = &free; - int localtime_s(struct tm *const pTm, time_t const *const pTime) { return go_localtime(pTm, (sqlite3_int64)*pTime); } diff --git a/sqlite3/vtab.c b/sqlite3/vtab.c index 6daee2b..8134118 100644 --- a/sqlite3/vtab.c +++ b/sqlite3/vtab.c @@ -1,28 +1,30 @@ #include #include "sqlite3.h" +#include "types.h" // https://github.com/JuliaLang/julia/blob/v1.9.4/src/julia.h#L67-L68 #define container_of(ptr, type, member) \ ((type *)((char *)(ptr)-offsetof(type, member))) -#define SQLITE_MOD_CREATOR_GO /*******/ 0x01 -#define SQLITE_VTAB_UPDATER_GO /******/ 0x02 -#define SQLITE_VTAB_RENAMER_GO /******/ 0x04 -#define SQLITE_VTAB_OVERLOADER_GO /***/ 0x08 -#define SQLITE_VTAB_CHECKER_GO /******/ 0x10 -#define SQLITE_VTAB_TX_GO /***********/ 0x20 -#define SQLITE_VTAB_SAVEPOINTER_GO /**/ 0x40 +#define SQLITE_VTAB_CREATOR_GO /******/ 0x01 +#define SQLITE_VTAB_DESTROYER_GO /****/ 0x02 +#define SQLITE_VTAB_UPDATER_GO /******/ 0x04 +#define SQLITE_VTAB_RENAMER_GO /******/ 0x08 +#define SQLITE_VTAB_OVERLOADER_GO /***/ 0x10 +#define SQLITE_VTAB_CHECKER_GO /******/ 0x20 +#define SQLITE_VTAB_TX_GO /***********/ 0x40 +#define SQLITE_VTAB_SAVEPOINTER_GO /**/ 0x80 -int go_mod_create(sqlite3_module *, int argc, const char *const *argv, - sqlite3_vtab **, char **pzErr); -int go_mod_connect(sqlite3_module *, int argc, const char *const *argv, +int go_vtab_create(sqlite3_module *, int argc, const char *const *argv, sqlite3_vtab **, char **pzErr); +int go_vtab_connect(sqlite3_module *, int argc, const char *const *argv, + sqlite3_vtab **, char **pzErr); int go_vtab_disconnect(sqlite3_vtab *); int go_vtab_destroy(sqlite3_vtab *); int go_vtab_best_index(sqlite3_vtab *, sqlite3_index_info *); -int go_vtab_open(sqlite3_vtab *, sqlite3_vtab_cursor **); +int go_cur_open(sqlite3_vtab *, sqlite3_vtab_cursor **); int go_cur_close(sqlite3_vtab_cursor *); int go_cur_filter(sqlite3_vtab_cursor *, int idxNum, const char *idxStr, @@ -71,23 +73,7 @@ static void go_mod_destroy(void *pAux) { go_destroy(handle); } -static int go_mod_create_wrapper(sqlite3 *db, void *pAux, int argc, - const char *const *argv, sqlite3_vtab **ppVTab, - char **pzErr) { - struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab)); - if (vtab == NULL) return SQLITE_NOMEM; - *ppVTab = &vtab->base; - - struct go_module *mod = (struct go_module *)pAux; - int rc = go_mod_create(&mod->base, argc, argv, ppVTab, pzErr); - if (rc) { - if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr); - free(vtab); - } - return rc; -} - -static int go_mod_connect_wrapper(sqlite3 *db, void *pAux, int argc, +static int go_vtab_create_wrapper(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVTab, char **pzErr) { struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab)); @@ -95,7 +81,23 @@ static int go_mod_connect_wrapper(sqlite3 *db, void *pAux, int argc, *ppVTab = &vtab->base; struct go_module *mod = (struct go_module *)pAux; - int rc = go_mod_connect(&mod->base, argc, argv, ppVTab, pzErr); + int rc = go_vtab_create(&mod->base, argc, argv, ppVTab, pzErr); + if (rc) { + if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr); + free(vtab); + } + return rc; +} + +static int go_vtab_connect_wrapper(sqlite3 *db, void *pAux, int argc, + const char *const *argv, + sqlite3_vtab **ppVTab, char **pzErr) { + struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab)); + if (vtab == NULL) return SQLITE_NOMEM; + *ppVTab = &vtab->base; + + struct go_module *mod = (struct go_module *)pAux; + int rc = go_vtab_connect(&mod->base, argc, argv, ppVTab, pzErr); if (rc) { free(vtab); if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr); @@ -117,13 +119,13 @@ static int go_vtab_destroy_wrapper(sqlite3_vtab *pVTab) { return rc; } -static int go_vtab_open_wrapper(sqlite3_vtab *pVTab, - sqlite3_vtab_cursor **ppCursor) { +static int go_cur_open_wrapper(sqlite3_vtab *pVTab, + sqlite3_vtab_cursor **ppCursor) { struct go_cursor *cur = calloc(1, sizeof(struct go_cursor)); if (cur == NULL) return SQLITE_NOMEM; *ppCursor = &cur->base; - int rc = go_vtab_open(pVTab, ppCursor); + int rc = go_cur_open(pVTab, ppCursor); if (rc) free(cur); return rc; } @@ -158,7 +160,7 @@ static int go_vtab_integrity_wrapper(sqlite3_vtab *pVTab, const char *zSchema, } int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags, - void *handle) { + go_handle handle) { struct go_module *mod = malloc(sizeof(struct go_module)); if (mod == NULL) { go_destroy(handle); @@ -168,10 +170,10 @@ int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags, mod->handle = handle; mod->base = (sqlite3_module){ .iVersion = 4, - .xConnect = go_mod_connect_wrapper, + .xConnect = go_vtab_connect_wrapper, .xDisconnect = go_vtab_disconnect_wrapper, .xBestIndex = go_vtab_best_index, - .xOpen = go_vtab_open_wrapper, + .xOpen = go_cur_open_wrapper, .xClose = go_cur_close_wrapper, .xFilter = go_cur_filter, .xNext = go_cur_next, @@ -179,9 +181,14 @@ int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags, .xColumn = go_cur_column, .xRowid = go_cur_rowid, }; - if (flags & SQLITE_MOD_CREATOR_GO) { - mod->base.xCreate = go_mod_create_wrapper; - mod->base.xDestroy = go_vtab_destroy_wrapper; + if (flags & SQLITE_VTAB_CREATOR_GO) { + if (flags & SQLITE_VTAB_DESTROYER_GO) { + mod->base.xCreate = go_vtab_create_wrapper; + mod->base.xDestroy = go_vtab_destroy_wrapper; + } else { + mod->base.xCreate = mod->base.xConnect; + mod->base.xDestroy = mod->base.xDisconnect; + } } if (flags & SQLITE_VTAB_UPDATER_GO) { mod->base.xUpdate = go_vtab_update; @@ -210,6 +217,10 @@ int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags, return sqlite3_create_module_v2(db, zName, &mod->base, mod, go_mod_destroy); } +int sqlite3_vtab_config_go(sqlite3 *db, int op, int constraint) { + return sqlite3_vtab_config(db, op, constraint); +} + static_assert(offsetof(struct go_module, base) == 4, "Unexpected offset"); static_assert(offsetof(struct go_vtab, base) == 4, "Unexpected offset"); static_assert(offsetof(struct go_cursor, base) == 4, "Unexpected offset"); \ No newline at end of file diff --git a/tests/db_test.go b/tests/db_test.go index 3ff27f4..54acbde 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -46,7 +46,7 @@ func TestDB_vfs(t *testing.T) { testDB(t, "file:test.db?vfs=memdb") } -func testDB(t *testing.T, name string) { +func testDB(t testing.TB, name string) { db, err := sqlite3.Open(name) if err != nil { t.Fatal(err) diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index 42df69f..f920235 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -11,7 +11,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" - _ "github.com/ncruces/go-sqlite3/vfs/memdb" + "github.com/ncruces/go-sqlite3/vfs/memdb" ) func TestParallel(t *testing.T) { @@ -96,7 +96,16 @@ func TestChildProcess(t *testing.T) { testParallel(t, name, 1000) } -func testParallel(t *testing.T, name string, n int) { +func BenchmarkMemory(b *testing.B) { + memdb.Delete("test.db") + name := "file:/test.db?vfs=memdb" + + "&_pragma=busy_timeout(10000)" + + "&_pragma=journal_mode(memory)" + + "&_pragma=synchronous(off)" + testParallel(b, name, b.N) +} + +func testParallel(t testing.TB, name string, n int) { writer := func() error { db, err := sqlite3.Open(name) if err != nil { @@ -174,7 +183,7 @@ func testParallel(t *testing.T, name string, n int) { } } -func testIntegrity(t *testing.T, name string) { +func testIntegrity(t testing.TB, name string) { db, err := sqlite3.Open(name) if err != nil { t.Fatal(err) diff --git a/vtab.go b/vtab.go index e426a97..cbd001e 100644 --- a/vtab.go +++ b/vtab.go @@ -1,21 +1,93 @@ package sqlite3 +import ( + "context" + "reflect" + + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/tetratelabs/wazero/api" +) + +// CreateModule register a new virtual table module name. +func CreateModule[T VTab](conn *Conn, name string, module Module[T]) error { + var flags int + + const ( + VTAB_CREATOR = 0x01 + VTAB_DESTROYER = 0x02 + VTAB_UPDATER = 0x04 + VTAB_RENAMER = 0x08 + VTAB_OVERLOADER = 0x10 + VTAB_CHECKER = 0x20 + VTAB_TX = 0x40 + VTAB_SAVEPOINTER = 0x80 + ) + + create, ok := reflect.TypeOf(module).MethodByName("Create") + connect, _ := reflect.TypeOf(module).MethodByName("Connect") + if ok && create.Type == connect.Type { + flags |= VTAB_CREATOR + } + + vtab := connect.Type.Out(0) + if implements[VTabDestroyer](vtab) { + flags |= VTAB_DESTROYER + } + if implements[VTabUpdater](vtab) { + flags |= VTAB_UPDATER + } + if implements[VTabRenamer](vtab) { + flags |= VTAB_RENAMER + } + if implements[VTabOverloader](vtab) { + flags |= VTAB_OVERLOADER + } + if implements[VTabChecker](vtab) { + flags |= VTAB_CHECKER + } + if implements[VTabTx](vtab) { + flags |= VTAB_TX + } + if implements[VTabSavepointer](vtab) { + flags |= VTAB_SAVEPOINTER + } + + defer conn.arena.reset() + namePtr := conn.arena.string(name) + modulePtr := util.AddHandle(conn.ctx, module) + r := conn.call(conn.api.createModule, uint64(conn.handle), + uint64(namePtr), uint64(flags), uint64(modulePtr)) + return conn.error(r) +} + +func implements[T any](typ reflect.Type) bool { + var ptr *T + return typ.Implements(reflect.TypeOf(ptr).Elem()) +} + +func (c *Conn) DeclareVtab(sql string) error { + defer c.arena.reset() + sqlPtr := c.arena.string(sql) + r := c.call(c.api.declareVTab, uint64(c.handle), uint64(sqlPtr)) + return c.error(r) +} + // A Module defines the implementation of a virtual table. -// Modules that don't also implement [ModuleCreator] provide +// A Module that doesn't implement [ModuleCreator] provides // eponymous-only virtual tables or table-valued functions. // // https://sqlite.org/c3ref/module.html -type Module interface { +type Module[T VTab] interface { // https://sqlite.org/vtab.html#xconnect - Connect(db *Conn, arg ...string) (VTab, error) + Connect(c *Conn, arg ...string) (T, error) } -// A ModuleCreator extends Module for -// non-eponymous virtual tables. -type ModuleCreator interface { - Module +// A ModuleCreator allows virtual tables to be created. +// A persistent virtual table must implement [VTabDestroyer]. +type ModuleCreator[T VTab] interface { + Module[T] // https://sqlite.org/vtab.html#xcreate - Create(db *Conn, arg ...string) (VTabDestroyer, error) + Create(c *Conn, arg ...string) (T, error) } // A VTab describes a particular instance of the virtual table. @@ -30,7 +102,7 @@ type VTab interface { Open() (VTabCursor, error) } -// A VTabDestroyer allows a virtual table to be destroyed. +// A VTabDestroyer allows a persistent virtual table to be destroyed. type VTabDestroyer interface { VTab // https://sqlite.org/vtab.html#sqlite3_module.xDestroy @@ -173,3 +245,78 @@ type IndexScanFlag uint8 const ( Unique IndexScanFlag = 1 ) + +func vtabConnectCallback(ctx context.Context, mod api.Module, pMod, argc, argv, ppVTab, pzErr uint32) uint32 { + const handleOffset = 4 + handle := util.ReadUint32(mod, pMod-handleOffset) + module := util.GetHandle(ctx, handle) + db := ctx.Value(connKey{}).(*Conn) + + arg := make([]reflect.Value, 1+argc) + arg[0] = reflect.ValueOf(db) + + for i := uint32(0); i < argc; i++ { + ptr := util.ReadUint32(mod, argv+i*ptrlen) + arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING)) + } + + res := reflect.ValueOf(module).MethodByName("Connect").Call(arg) + err, _ := res[1].Interface().(error) + if err == nil { + handle := util.AddHandle(ctx, res[0].Interface()) + ptr := util.ReadUint32(mod, ppVTab) + util.WriteUint32(mod, ptr-handleOffset, handle) + return _OK + } + + // TODO: error message + return errorCode(err, ERROR) +} + +func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 { + const handleOffset = 4 + handle := util.ReadUint32(mod, pVTab-handleOffset) + vtab := util.GetHandle(ctx, handle).(VTab) + _ = vtab + return 1 +} + +func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { + return 1 +} + +func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 { + return 1 +} + +func vtabCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 { + return 1 +} + +func vtabCallbackII(ctx context.Context, mod api.Module, _, _ uint32) uint32 { + return 1 +} + +func vtabCallbackIIII(ctx context.Context, mod api.Module, _, _, _, _ uint32) uint32 { + return 1 +} + +func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32) uint32 { + return 1 +} + +func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idxStr, argc, argv uint32) uint32 { + return 1 +} + +func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx, n uint32) uint32 { + return 1 +} + +func cursorRowidCallback(ctx context.Context, mod api.Module, pCur, pRowid uint32) uint32 { + return 1 +} + +func cursorCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 { + return 1 +}