diff --git a/context.go b/context.go index 6de20bb..6266936 100644 --- a/context.go +++ b/context.go @@ -179,7 +179,7 @@ func (ctx Context) ResultJSON(value any) { ctx.ResultRawText(data) } -// ResultValue sets the result of the function a copy of [Value]. +// ResultValue sets the result of the function to a copy of [Value]. // // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultValue(value Value) { diff --git a/internal/util/mem.go b/internal/util/mem.go index 11f3735..a09523f 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -24,6 +24,17 @@ func View(mod api.Module, ptr uint32, size uint64) []byte { return buf } +func ReadUint8(mod api.Module, ptr uint32) uint8 { + if ptr == 0 { + panic(NilErr) + } + v, ok := mod.Memory().ReadByte(ptr) + if !ok { + panic(RangeErr) + } + return v +} + func ReadUint32(mod api.Module, ptr uint32) uint32 { if ptr == 0 { panic(NilErr) @@ -35,6 +46,16 @@ func ReadUint32(mod api.Module, ptr uint32) uint32 { return v } +func WriteUint8(mod api.Module, ptr uint32, v uint8) { + if ptr == 0 { + panic(NilErr) + } + ok := mod.Memory().WriteByte(ptr, v) + if !ok { + panic(RangeErr) + } +} + func WriteUint32(mod api.Module, ptr uint32, v uint32) { if ptr == 0 { panic(NilErr) diff --git a/internal/util/mem_test.go b/internal/util/mem_test.go index d18a2b8..733ab34 100644 --- a/internal/util/mem_test.go +++ b/internal/util/mem_test.go @@ -28,6 +28,20 @@ func TestView_overflow(t *testing.T) { t.Error("want panic") } +func TestReadUint8_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) + ReadUint8(mock, 0) + t.Error("want panic") +} + +func TestReadUint8_range(t *testing.T) { + defer func() { _ = recover() }() + mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) + ReadUint8(mock, wazerotest.PageSize) + t.Error("want panic") +} + func TestReadUint32_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) @@ -56,6 +70,20 @@ func TestReadUint64_range(t *testing.T) { t.Error("want panic") } +func TestWriteUint8_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) + WriteUint8(mock, 0, 1) + t.Error("want panic") +} + +func TestWriteUint8_range(t *testing.T) { + defer func() { _ = recover() }() + mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) + WriteUint8(mock, wazerotest.PageSize, 1) + t.Error("want panic") +} + func TestWriteUint32_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) diff --git a/sqlite.go b/sqlite.go index 6de79a0..2597f6d 100644 --- a/sqlite.go +++ b/sqlite.go @@ -423,10 +423,10 @@ func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncVI(env, "go_value", valueCallback) util.ExportFuncVIII(env, "go_inverse", inverseCallback) util.ExportFuncIIIIII(env, "go_compare", compareCallback) - util.ExportFuncIIIIII(env, "go_vtab_create", vtabConnectCallback) - util.ExportFuncIIIIII(env, "go_vtab_connect", vtabConnectCallback) + util.ExportFuncIIIIII(env, "go_vtab_create", vtabReflectCallback("Create")) + util.ExportFuncIIIIII(env, "go_vtab_connect", vtabReflectCallback("Connect")) util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback) - util.ExportFuncII(env, "go_vtab_destroy", vtabDisconnectCallback) + util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback) util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback) util.ExportFuncIIIII(env, "go_vtab_update", vtabCallbackIIII) util.ExportFuncIII(env, "go_vtab_rename", vtabCallbackII) diff --git a/sqlite3/vtab.c b/sqlite3/vtab.c index 8134118..dccf1f1 100644 --- a/sqlite3/vtab.c +++ b/sqlite3/vtab.c @@ -223,4 +223,8 @@ int sqlite3_vtab_config_go(sqlite3 *db, int op, int constraint) { static_assert(offsetof(struct go_module, base) == 4, "Unexpected offset"); static_assert(offsetof(struct go_vtab, base) == 4, "Unexpected offset"); -static_assert(offsetof(struct go_cursor, base) == 4, "Unexpected offset"); \ No newline at end of file +static_assert(offsetof(struct go_cursor, base) == 4, "Unexpected offset"); +static_assert(sizeof(struct sqlite3_index_info) == 72, "Unexpected size"); +static_assert(sizeof(struct sqlite3_index_constraint) == 12, "Unexpected size"); +static_assert(sizeof(struct sqlite3_index_constraint_usage) == 8, "Unexpected size"); +static_assert(sizeof(struct sqlite3_index_orderby) == 8, "Unexpected size"); \ No newline at end of file diff --git a/vtab.go b/vtab.go index cbd001e..33f01e3 100644 --- a/vtab.go +++ b/vtab.go @@ -189,20 +189,10 @@ type VTabCursor interface { // https://sqlite.org/c3ref/index_info.html type IndexInfo struct { /* Inputs */ - Constraint []struct { - Column int - Op IndexConstraintOp - Usable bool - } - OrderBy []struct { - Column int - Desc bool - } + Constraint []IndexConstraint + OrderBy []IndexOrderBy /* Outputs */ - ConstraintUsage []struct { - ArgvIndex int - Omit bool - } + ConstraintUsage []IndexConstraintUsage IdxNum int IdxStr string IdxFlags IndexScanFlag @@ -212,6 +202,85 @@ type IndexInfo struct { ColumnsUsed int64 } +// An IndexConstraint describes virtual table indexing constraint information. +// +// https://sqlite.org/c3ref/index_info.html +type IndexConstraint struct { + Column int + Op IndexConstraintOp + Usable bool +} + +// An IndexOrderBy describes virtual table indexing order by information. +// +// https://sqlite.org/c3ref/index_info.html +type IndexOrderBy struct { + Column int + Desc bool +} + +// An IndexConstraintUsage describes how virtual table indexing constraints will be used. +// +// https://sqlite.org/c3ref/index_info.html +type IndexConstraintUsage struct { + ArgvIndex int + Omit bool +} + +func (idx *IndexInfo) load(ctx context.Context, mod api.Module, ptr uint32) { + // https://sqlite.org/c3ref/index_info.html + + idx.Constraint = make([]IndexConstraint, util.ReadUint32(mod, ptr+0)) + idx.ConstraintUsage = make([]IndexConstraintUsage, util.ReadUint32(mod, ptr+0)) + idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8)) + + constraintPtr := util.ReadUint32(mod, ptr+4) + for i := range idx.Constraint { + idx.Constraint[i] = IndexConstraint{ + Column: int(util.ReadUint32(mod, constraintPtr+0)), + Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)), + Usable: util.ReadUint8(mod, constraintPtr+8) != 0, + } + constraintPtr += 12 + } + + orderByPtr := util.ReadUint32(mod, ptr+12) + for i := range idx.OrderBy { + idx.OrderBy[i] = IndexOrderBy{ + Column: int(util.ReadUint32(mod, orderByPtr+0)), + Desc: util.ReadUint8(mod, orderByPtr+4) != 0, + } + orderByPtr += 8 + } +} + +func (idx *IndexInfo) save(ctx context.Context, mod api.Module, ptr uint32) { + // https://sqlite.org/c3ref/index_info.html + + usagePtr := util.ReadUint32(mod, ptr+16) + for _, usage := range idx.ConstraintUsage { + util.WriteUint32(mod, usagePtr+0, uint32(usage.ArgvIndex)) + if usage.Omit { + util.WriteUint8(mod, usagePtr+4, 1) + } + usagePtr += 8 + } + + util.WriteUint32(mod, ptr+20, uint32(idx.IdxNum)) + if idx.IdxStr != "" { + conn := ctx.Value(connKey{}).(*Conn) + util.WriteUint32(mod, ptr+24, conn.newString(idx.IdxStr)) + util.WriteUint32(mod, ptr+28, 1) + } + if idx.OrderByConsumed { + util.WriteUint32(mod, ptr+32, 1) + } + util.WriteFloat64(mod, ptr+40, idx.EstimatedCost) + util.WriteUint64(mod, ptr+48, uint64(idx.EstimatedRows)) + util.WriteUint32(mod, ptr+56, uint32(idx.IdxFlags)) + util.WriteUint64(mod, ptr+64, uint64(idx.ColumnsUsed)) +} + // IndexConstraintOp is a virtual table constraint operator code. // // https://sqlite.org/c3ref/c_index_constraint_eq.html @@ -240,83 +309,108 @@ const ( // IndexScanFlag is a virtual table scan flag. // // https://www.sqlite.org/c3ref/c_index_scan_unique.html -type IndexScanFlag uint8 +type IndexScanFlag uint32 const ( Unique IndexScanFlag = 1 ) -func vtabConnectCallback(ctx context.Context, mod api.Module, pMod, argc, argv, ppVTab, pzErr uint32) uint32 { - const handleOffset = 4 - handle := util.ReadUint32(mod, pMod-handleOffset) - module := util.GetHandle(ctx, handle) - db := ctx.Value(connKey{}).(*Conn) +func vtabReflectCallback(name string) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { + return func(ctx context.Context, mod api.Module, pMod, argc, argv, ppVTab, pzErr uint32) uint32 { + module := vtabGetHandle(ctx, mod, pMod) + db := ctx.Value(connKey{}).(*Conn) - arg := make([]reflect.Value, 1+argc) - arg[0] = reflect.ValueOf(db) + arg := make([]reflect.Value, 1+argc) + arg[0] = reflect.ValueOf(db) - for i := uint32(0); i < argc; i++ { - ptr := util.ReadUint32(mod, argv+i*ptrlen) - arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING)) + for i := uint32(0); i < argc; i++ { + ptr := util.ReadUint32(mod, argv+i*ptrlen) + arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING)) + } + + res := reflect.ValueOf(module).MethodByName(name).Call(arg) + err, _ := res[1].Interface().(error) + if err == nil { + vtabPutHandle(ctx, mod, ppVTab, res[0].Interface()) + return _OK + } + + // TODO: error message? + return errorCode(err, ERROR) } - - res := reflect.ValueOf(module).MethodByName("Connect").Call(arg) - err, _ := res[1].Interface().(error) - if err == nil { - handle := util.AddHandle(ctx, res[0].Interface()) - ptr := util.ReadUint32(mod, ppVTab) - util.WriteUint32(mod, ptr-handleOffset, handle) - return _OK - } - - // TODO: error message - return errorCode(err, ERROR) -} - -func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 { - const handleOffset = 4 - handle := util.ReadUint32(mod, pVTab-handleOffset) - vtab := util.GetHandle(ctx, handle).(VTab) - _ = vtab - return 1 } func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { - return 1 + vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) + err := vtab.Disconnect() + // TODO: error message? + return errorCode(err, _OK) +} + +func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { + vtab := vtabGetHandle(ctx, mod, pVTab).(VTabDestroyer) + err := vtab.Destroy() + // TODO: error message? + return errorCode(err, _OK) +} + +func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 { + var info IndexInfo + info.load(ctx, mod, pIdxInfo) + + vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) + err := vtab.BestIndex(&info) + + info.save(ctx, mod, pIdxInfo) + // TODO: error message? + return errorCode(err, _OK) } func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 { - return 1 + return uint32(ERROR) } func vtabCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 { - return 1 + return uint32(ERROR) } func vtabCallbackII(ctx context.Context, mod api.Module, _, _ uint32) uint32 { - return 1 + return uint32(ERROR) } func vtabCallbackIIII(ctx context.Context, mod api.Module, _, _, _, _ uint32) uint32 { - return 1 + return uint32(ERROR) } func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32) uint32 { - return 1 + return uint32(ERROR) } func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idxStr, argc, argv uint32) uint32 { - return 1 + return uint32(ERROR) } func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx, n uint32) uint32 { - return 1 + return uint32(ERROR) } func cursorRowidCallback(ctx context.Context, mod api.Module, pCur, pRowid uint32) uint32 { - return 1 + return uint32(ERROR) } func cursorCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 { - return 1 + return uint32(ERROR) +} + +func vtabGetHandle(ctx context.Context, mod api.Module, ptr uint32) any { + const handleOffset = 4 + handle := util.ReadUint32(mod, ptr-handleOffset) + return util.GetHandle(ctx, handle) +} + +func vtabPutHandle(ctx context.Context, mod api.Module, pptr uint32, val any) { + const handleOffset = 4 + handle := util.AddHandle(ctx, val) + ptr := util.ReadUint32(mod, pptr) + util.WriteUint32(mod, ptr-handleOffset, handle) } diff --git a/vtab_test.go b/vtab_test.go new file mode 100644 index 0000000..649b440 --- /dev/null +++ b/vtab_test.go @@ -0,0 +1,67 @@ +package sqlite3_test + +import ( + "log" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func ExampleCreateModule() { + db, err := sqlite3.Open(":memory:") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + err = sqlite3.CreateModule(db, "generate_series", seriesModule{}) + if err != nil { + log.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT value FROM generate_series(5,100,5)`) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + // Output: +} + +type seriesModule struct{} + +func (seriesModule) Connect(c *sqlite3.Conn, arg ...string) (*seriesTable, error) { + err := c.DeclareVtab(`CREATE TABLE x(value, start HIDDEN, stop HIDDEN, step HIDDEN)`) + if err != nil { + return nil, err + } + return &seriesTable{0, 0, 1}, nil +} + +type seriesTable struct { + start int64 + stop int64 + step int64 +} + +func (*seriesTable) Disconnect() error { + return nil +} + +func (*seriesTable) BestIndex(idx *sqlite3.IndexInfo) error { + idx.IdxNum = 0 + idx.IdxStr = "default" + argv := 1 + for i, cst := range idx.Constraint { + if cst.Usable && cst.Op == sqlite3.Eq { + idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ + ArgvIndex: argv, + Omit: true, + } + argv++ + } + } + return nil +} + +func (*seriesTable) Open() (sqlite3.VTabCursor, error) { return nil, nil }