Towards virtual tables.

This commit is contained in:
Nuno Cruces
2023-11-16 01:16:38 +00:00
parent 314098addb
commit 787086b8c1
17 changed files with 317 additions and 97 deletions

View File

@@ -265,7 +265,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
return old
}
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt != nil && c.interrupt.Err() != nil {
return 1

View File

@@ -210,17 +210,8 @@ func (ctx Context) ResultError(err error) {
uint64(ctx.handle), uint64(ptr), uint64(len(str)))
ctx.c.free(ptr)
var code uint64
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
code = uint64(xcode)
case errors.As(err, &ecode):
code = uint64(ecode)
}
if code != 0 {
if code := errorCode(err, _OK); code != _OK {
ctx.c.call(ctx.c.api.resultErrorCode,
uint64(ctx.handle), code)
uint64(ctx.handle), uint64(code))
}
}

View File

@@ -58,7 +58,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(639095955742, 222_222_222) // twosday, year 22222AD
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
checkTime := func(t testing.TB, date time.Time) {
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
switch v := value.(type) {

View File

@@ -77,4 +77,15 @@ sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem
sqlite3_result_error_toobig
sqlite3_result_error_toobig
sqlite3_create_module_go
sqlite3_declare_vtab
sqlite3_vtab_config_go
sqlite3_vtab_collation
sqlite3_vtab_distinct
sqlite3_vtab_in
sqlite3_vtab_in_first
sqlite3_vtab_in_next
sqlite3_vtab_rhs_value
sqlite3_vtab_nochange
sqlite3_vtab_on_conflict

Binary file not shown.

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"errors"
"strconv"
"strings"
@@ -135,3 +136,18 @@ func (e ExtendedErrorCode) Temporary() bool {
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
func errorCode(err error, def ErrorCode) (code uint32) {
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
return uint32(xcode)
case errors.As(err, &ecode):
return uint32(ecode)
}
if err != nil {
return uint32(def)
}
return _OK
}

33
func.go
View File

@@ -21,6 +21,7 @@ func (c *Conn) AnyCollationNeeded() {
//
// https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
defer c.arena.reset()
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createCollation,
@@ -32,6 +33,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
defer c.arena.reset()
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
@@ -46,6 +48,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
defer c.arena.reset()
call := c.api.createAggregate
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
@@ -81,55 +84,55 @@ type WindowFunction interface {
Inverse(ctx Context, arg ...Value)
}
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := callbackHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn := userDataHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
func stepCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
fn := aggregateCtxHandle(db, pCtx, nil).(AggregateFunction)
fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
func finalCallback(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, &handle).(AggregateFunction)
fn := aggregateCtxHandle(db, pCtx, &handle).(AggregateFunction)
fn.Value(Context{db, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{db, pCtx}.ResultError(err)
}
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
func valueCallback(ctx context.Context, mod api.Module, pCtx uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
fn := aggregateCtxHandle(db, pCtx, nil).(AggregateFunction)
fn.Value(Context{db, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
func inverseCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(WindowFunction)
fn := aggregateCtxHandle(db, pCtx, nil).(WindowFunction)
fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackHandle(db *Conn, pCtx uint32) any {
func userDataHandle(db *Conn, pCtx uint32) any {
pApp := uint32(db.call(db.api.userData, uint64(pCtx)))
return util.GetHandle(db.ctx, pApp)
}
func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any {
func aggregateCtxHandle(db *Conn, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
@@ -152,7 +155,7 @@ func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any {
}
// Create a new aggregate and store the handle.
fn := callbackHandle(db, pCtx).(func() AggregateFunction)()
fn := userDataHandle(db, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn))
}

View File

@@ -183,6 +183,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
resultErrorBig: getFun("sqlite3_result_error_toobig"),
createModule: getFun("sqlite3_create_module_go"),
declareVTab: getFun("sqlite3_declare_vtab"),
}
if err != nil {
return nil, err
@@ -407,17 +409,42 @@ type sqliteAPI struct {
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
createModule api.Function
declareVTab api.Function
destructor uint32
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", callbackProgress)
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncII(env, "go_progress", progressCallback)
util.ExportFuncVI(env, "go_destroy", destroyCallback)
util.ExportFuncVIII(env, "go_func", funcCallback)
util.ExportFuncVIII(env, "go_step", stepCallback)
util.ExportFuncVI(env, "go_final", finalCallback)
util.ExportFuncVI(env, "go_value", valueCallback)
util.ExportFuncVIII(env, "go_inverse", inverseCallback)
util.ExportFuncIIIIII(env, "go_compare", compareCallback)
util.ExportFuncIIIIII(env, "go_vtab_create", vtabConnectCallback)
util.ExportFuncIIIIII(env, "go_vtab_connect", vtabConnectCallback)
util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
util.ExportFuncII(env, "go_vtab_destroy", vtabDisconnectCallback)
util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
util.ExportFuncIIIII(env, "go_vtab_update", vtabCallbackIIII)
util.ExportFuncIII(env, "go_vtab_rename", vtabCallbackII)
util.ExportFuncIIIII(env, "go_vtab_find_function", vtabCallbackIIII)
util.ExportFuncII(env, "go_vtab_begin", vtabCallbackI)
util.ExportFuncII(env, "go_vtab_sync", vtabCallbackI)
util.ExportFuncII(env, "go_vtab_commit", vtabCallbackI)
util.ExportFuncII(env, "go_vtab_rollback", vtabCallbackI)
util.ExportFuncIII(env, "go_vtab_savepoint", vtabCallbackII)
util.ExportFuncIII(env, "go_vtab_release", vtabCallbackII)
util.ExportFuncIII(env, "go_vtab_rollback_to", vtabCallbackII)
util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
util.ExportFuncII(env, "go_cur_close", cursorCallbackI)
util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
util.ExportFuncII(env, "go_cur_next", cursorCallbackI)
util.ExportFuncII(env, "go_cur_eof", cursorCallbackI)
util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
util.ExportFuncIII(env, "go_cur_rowid", cursorRowidCallback)
return env
}

View File

@@ -1,12 +1,7 @@
#include <stddef.h>
#include "sqlite3.h"
typedef void *go_handle;
void go_destroy(go_handle);
static_assert(sizeof(go_handle) == 4, "Unexpected size");
#include "types.h"
void go_func(sqlite3_context *, int, sqlite3_value **);
void go_step(sqlite3_context *, int, sqlite3_value **);

View File

@@ -8,12 +8,15 @@
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
// Bindings
#include "func.c"
#include "pointer.c"
#include "progress.c"
#include "time.c"
#include "vfs.c"
// #include "vtab.c"
#include "vtab.c"
sqlite3_destructor_type malloc_destructor = &free;
__attribute__((constructor)) void init() {
sqlite3_initialize();

View File

@@ -1,5 +1,6 @@
#include "sqlite3.h"
#include "types.h"
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"

7
sqlite3/types.h Normal file
View File

@@ -0,0 +1,7 @@
#pragma once
typedef void *go_handle;
void go_destroy(go_handle);
static_assert(sizeof(go_handle) == 4, "Unexpected size");

View File

@@ -3,6 +3,7 @@
#include <time.h>
#include "sqlite3.h"
#include "types.h"
int go_localtime(struct tm *, sqlite3_int64);
int go_vfs_find(const char *zVfsName);
@@ -83,8 +84,6 @@ int sqlite3_os_init() {
return sqlite3_vfs_register(&os_vfs, /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime(pTm, (sqlite3_int64)*pTime);
}

View File

@@ -1,28 +1,30 @@
#include <stddef.h>
#include "sqlite3.h"
#include "types.h"
// https://github.com/JuliaLang/julia/blob/v1.9.4/src/julia.h#L67-L68
#define container_of(ptr, type, member) \
((type *)((char *)(ptr)-offsetof(type, member)))
#define SQLITE_MOD_CREATOR_GO /*******/ 0x01
#define SQLITE_VTAB_UPDATER_GO /******/ 0x02
#define SQLITE_VTAB_RENAMER_GO /******/ 0x04
#define SQLITE_VTAB_OVERLOADER_GO /***/ 0x08
#define SQLITE_VTAB_CHECKER_GO /******/ 0x10
#define SQLITE_VTAB_TX_GO /***********/ 0x20
#define SQLITE_VTAB_SAVEPOINTER_GO /**/ 0x40
#define SQLITE_VTAB_CREATOR_GO /******/ 0x01
#define SQLITE_VTAB_DESTROYER_GO /****/ 0x02
#define SQLITE_VTAB_UPDATER_GO /******/ 0x04
#define SQLITE_VTAB_RENAMER_GO /******/ 0x08
#define SQLITE_VTAB_OVERLOADER_GO /***/ 0x10
#define SQLITE_VTAB_CHECKER_GO /******/ 0x20
#define SQLITE_VTAB_TX_GO /***********/ 0x40
#define SQLITE_VTAB_SAVEPOINTER_GO /**/ 0x80
int go_mod_create(sqlite3_module *, int argc, const char *const *argv,
sqlite3_vtab **, char **pzErr);
int go_mod_connect(sqlite3_module *, int argc, const char *const *argv,
int go_vtab_create(sqlite3_module *, int argc, const char *const *argv,
sqlite3_vtab **, char **pzErr);
int go_vtab_connect(sqlite3_module *, int argc, const char *const *argv,
sqlite3_vtab **, char **pzErr);
int go_vtab_disconnect(sqlite3_vtab *);
int go_vtab_destroy(sqlite3_vtab *);
int go_vtab_best_index(sqlite3_vtab *, sqlite3_index_info *);
int go_vtab_open(sqlite3_vtab *, sqlite3_vtab_cursor **);
int go_cur_open(sqlite3_vtab *, sqlite3_vtab_cursor **);
int go_cur_close(sqlite3_vtab_cursor *);
int go_cur_filter(sqlite3_vtab_cursor *, int idxNum, const char *idxStr,
@@ -71,23 +73,7 @@ static void go_mod_destroy(void *pAux) {
go_destroy(handle);
}
static int go_mod_create_wrapper(sqlite3 *db, void *pAux, int argc,
const char *const *argv, sqlite3_vtab **ppVTab,
char **pzErr) {
struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab));
if (vtab == NULL) return SQLITE_NOMEM;
*ppVTab = &vtab->base;
struct go_module *mod = (struct go_module *)pAux;
int rc = go_mod_create(&mod->base, argc, argv, ppVTab, pzErr);
if (rc) {
if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr);
free(vtab);
}
return rc;
}
static int go_mod_connect_wrapper(sqlite3 *db, void *pAux, int argc,
static int go_vtab_create_wrapper(sqlite3 *db, void *pAux, int argc,
const char *const *argv,
sqlite3_vtab **ppVTab, char **pzErr) {
struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab));
@@ -95,7 +81,23 @@ static int go_mod_connect_wrapper(sqlite3 *db, void *pAux, int argc,
*ppVTab = &vtab->base;
struct go_module *mod = (struct go_module *)pAux;
int rc = go_mod_connect(&mod->base, argc, argv, ppVTab, pzErr);
int rc = go_vtab_create(&mod->base, argc, argv, ppVTab, pzErr);
if (rc) {
if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr);
free(vtab);
}
return rc;
}
static int go_vtab_connect_wrapper(sqlite3 *db, void *pAux, int argc,
const char *const *argv,
sqlite3_vtab **ppVTab, char **pzErr) {
struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab));
if (vtab == NULL) return SQLITE_NOMEM;
*ppVTab = &vtab->base;
struct go_module *mod = (struct go_module *)pAux;
int rc = go_vtab_connect(&mod->base, argc, argv, ppVTab, pzErr);
if (rc) {
free(vtab);
if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr);
@@ -117,13 +119,13 @@ static int go_vtab_destroy_wrapper(sqlite3_vtab *pVTab) {
return rc;
}
static int go_vtab_open_wrapper(sqlite3_vtab *pVTab,
sqlite3_vtab_cursor **ppCursor) {
static int go_cur_open_wrapper(sqlite3_vtab *pVTab,
sqlite3_vtab_cursor **ppCursor) {
struct go_cursor *cur = calloc(1, sizeof(struct go_cursor));
if (cur == NULL) return SQLITE_NOMEM;
*ppCursor = &cur->base;
int rc = go_vtab_open(pVTab, ppCursor);
int rc = go_cur_open(pVTab, ppCursor);
if (rc) free(cur);
return rc;
}
@@ -158,7 +160,7 @@ static int go_vtab_integrity_wrapper(sqlite3_vtab *pVTab, const char *zSchema,
}
int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags,
void *handle) {
go_handle handle) {
struct go_module *mod = malloc(sizeof(struct go_module));
if (mod == NULL) {
go_destroy(handle);
@@ -168,10 +170,10 @@ int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags,
mod->handle = handle;
mod->base = (sqlite3_module){
.iVersion = 4,
.xConnect = go_mod_connect_wrapper,
.xConnect = go_vtab_connect_wrapper,
.xDisconnect = go_vtab_disconnect_wrapper,
.xBestIndex = go_vtab_best_index,
.xOpen = go_vtab_open_wrapper,
.xOpen = go_cur_open_wrapper,
.xClose = go_cur_close_wrapper,
.xFilter = go_cur_filter,
.xNext = go_cur_next,
@@ -179,9 +181,14 @@ int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags,
.xColumn = go_cur_column,
.xRowid = go_cur_rowid,
};
if (flags & SQLITE_MOD_CREATOR_GO) {
mod->base.xCreate = go_mod_create_wrapper;
mod->base.xDestroy = go_vtab_destroy_wrapper;
if (flags & SQLITE_VTAB_CREATOR_GO) {
if (flags & SQLITE_VTAB_DESTROYER_GO) {
mod->base.xCreate = go_vtab_create_wrapper;
mod->base.xDestroy = go_vtab_destroy_wrapper;
} else {
mod->base.xCreate = mod->base.xConnect;
mod->base.xDestroy = mod->base.xDisconnect;
}
}
if (flags & SQLITE_VTAB_UPDATER_GO) {
mod->base.xUpdate = go_vtab_update;
@@ -210,6 +217,10 @@ int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags,
return sqlite3_create_module_v2(db, zName, &mod->base, mod, go_mod_destroy);
}
int sqlite3_vtab_config_go(sqlite3 *db, int op, int constraint) {
return sqlite3_vtab_config(db, op, constraint);
}
static_assert(offsetof(struct go_module, base) == 4, "Unexpected offset");
static_assert(offsetof(struct go_vtab, base) == 4, "Unexpected offset");
static_assert(offsetof(struct go_cursor, base) == 4, "Unexpected offset");

View File

@@ -46,7 +46,7 @@ func TestDB_vfs(t *testing.T) {
testDB(t, "file:test.db?vfs=memdb")
}
func testDB(t *testing.T, name string) {
func testDB(t testing.TB, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)

View File

@@ -11,7 +11,7 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestParallel(t *testing.T) {
@@ -96,7 +96,16 @@ func TestChildProcess(t *testing.T) {
testParallel(t, name, 1000)
}
func testParallel(t *testing.T, name string, n int) {
func BenchmarkMemory(b *testing.B) {
memdb.Delete("test.db")
name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
testParallel(b, name, b.N)
}
func testParallel(t testing.TB, name string, n int) {
writer := func() error {
db, err := sqlite3.Open(name)
if err != nil {
@@ -174,7 +183,7 @@ func testParallel(t *testing.T, name string, n int) {
}
}
func testIntegrity(t *testing.T, name string) {
func testIntegrity(t testing.TB, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)

165
vtab.go
View File

@@ -1,21 +1,93 @@
package sqlite3
import (
"context"
"reflect"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// CreateModule register a new virtual table module name.
func CreateModule[T VTab](conn *Conn, name string, module Module[T]) error {
var flags int
const (
VTAB_CREATOR = 0x01
VTAB_DESTROYER = 0x02
VTAB_UPDATER = 0x04
VTAB_RENAMER = 0x08
VTAB_OVERLOADER = 0x10
VTAB_CHECKER = 0x20
VTAB_TX = 0x40
VTAB_SAVEPOINTER = 0x80
)
create, ok := reflect.TypeOf(module).MethodByName("Create")
connect, _ := reflect.TypeOf(module).MethodByName("Connect")
if ok && create.Type == connect.Type {
flags |= VTAB_CREATOR
}
vtab := connect.Type.Out(0)
if implements[VTabDestroyer](vtab) {
flags |= VTAB_DESTROYER
}
if implements[VTabUpdater](vtab) {
flags |= VTAB_UPDATER
}
if implements[VTabRenamer](vtab) {
flags |= VTAB_RENAMER
}
if implements[VTabOverloader](vtab) {
flags |= VTAB_OVERLOADER
}
if implements[VTabChecker](vtab) {
flags |= VTAB_CHECKER
}
if implements[VTabTx](vtab) {
flags |= VTAB_TX
}
if implements[VTabSavepointer](vtab) {
flags |= VTAB_SAVEPOINTER
}
defer conn.arena.reset()
namePtr := conn.arena.string(name)
modulePtr := util.AddHandle(conn.ctx, module)
r := conn.call(conn.api.createModule, uint64(conn.handle),
uint64(namePtr), uint64(flags), uint64(modulePtr))
return conn.error(r)
}
func implements[T any](typ reflect.Type) bool {
var ptr *T
return typ.Implements(reflect.TypeOf(ptr).Elem())
}
func (c *Conn) DeclareVtab(sql string) error {
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r := c.call(c.api.declareVTab, uint64(c.handle), uint64(sqlPtr))
return c.error(r)
}
// A Module defines the implementation of a virtual table.
// Modules that don't also implement [ModuleCreator] provide
// A Module that doesn't implement [ModuleCreator] provides
// eponymous-only virtual tables or table-valued functions.
//
// https://sqlite.org/c3ref/module.html
type Module interface {
type Module[T VTab] interface {
// https://sqlite.org/vtab.html#xconnect
Connect(db *Conn, arg ...string) (VTab, error)
Connect(c *Conn, arg ...string) (T, error)
}
// A ModuleCreator extends Module for
// non-eponymous virtual tables.
type ModuleCreator interface {
Module
// A ModuleCreator allows virtual tables to be created.
// A persistent virtual table must implement [VTabDestroyer].
type ModuleCreator[T VTab] interface {
Module[T]
// https://sqlite.org/vtab.html#xcreate
Create(db *Conn, arg ...string) (VTabDestroyer, error)
Create(c *Conn, arg ...string) (T, error)
}
// A VTab describes a particular instance of the virtual table.
@@ -30,7 +102,7 @@ type VTab interface {
Open() (VTabCursor, error)
}
// A VTabDestroyer allows a virtual table to be destroyed.
// A VTabDestroyer allows a persistent virtual table to be destroyed.
type VTabDestroyer interface {
VTab
// https://sqlite.org/vtab.html#sqlite3_module.xDestroy
@@ -173,3 +245,78 @@ type IndexScanFlag uint8
const (
Unique IndexScanFlag = 1
)
func vtabConnectCallback(ctx context.Context, mod api.Module, pMod, argc, argv, ppVTab, pzErr uint32) uint32 {
const handleOffset = 4
handle := util.ReadUint32(mod, pMod-handleOffset)
module := util.GetHandle(ctx, handle)
db := ctx.Value(connKey{}).(*Conn)
arg := make([]reflect.Value, 1+argc)
arg[0] = reflect.ValueOf(db)
for i := uint32(0); i < argc; i++ {
ptr := util.ReadUint32(mod, argv+i*ptrlen)
arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING))
}
res := reflect.ValueOf(module).MethodByName("Connect").Call(arg)
err, _ := res[1].Interface().(error)
if err == nil {
handle := util.AddHandle(ctx, res[0].Interface())
ptr := util.ReadUint32(mod, ppVTab)
util.WriteUint32(mod, ptr-handleOffset, handle)
return _OK
}
// TODO: error message
return errorCode(err, ERROR)
}
func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 {
const handleOffset = 4
handle := util.ReadUint32(mod, pVTab-handleOffset)
vtab := util.GetHandle(ctx, handle).(VTab)
_ = vtab
return 1
}
func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 {
return 1
}
func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 {
return 1
}
func vtabCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 {
return 1
}
func vtabCallbackII(ctx context.Context, mod api.Module, _, _ uint32) uint32 {
return 1
}
func vtabCallbackIIII(ctx context.Context, mod api.Module, _, _, _, _ uint32) uint32 {
return 1
}
func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32) uint32 {
return 1
}
func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idxStr, argc, argv uint32) uint32 {
return 1
}
func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx, n uint32) uint32 {
return 1
}
func cursorRowidCallback(ctx context.Context, mod api.Module, pCur, pRowid uint32) uint32 {
return 1
}
func cursorCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 {
return 1
}