diff --git a/sqlite.go b/sqlite.go index 2597f6d..d83b401 100644 --- a/sqlite.go +++ b/sqlite.go @@ -440,11 +440,11 @@ func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncIII(env, "go_vtab_rollback_to", vtabCallbackII) util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback) util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback) - util.ExportFuncII(env, "go_cur_close", cursorCallbackI) + util.ExportFuncII(env, "go_cur_close", cursorCloseCallback) util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback) - util.ExportFuncII(env, "go_cur_next", cursorCallbackI) - util.ExportFuncII(env, "go_cur_eof", cursorCallbackI) + util.ExportFuncII(env, "go_cur_next", cursorNextCallback) + util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback) util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback) - util.ExportFuncIII(env, "go_cur_rowid", cursorRowidCallback) + util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback) return env } diff --git a/vtab.go b/vtab.go index 33f01e3..8104222 100644 --- a/vtab.go +++ b/vtab.go @@ -383,23 +383,66 @@ func vtabCallbackIIII(ctx context.Context, mod api.Module, _, _, _, _ uint32) ui } func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32) uint32 { - return uint32(ERROR) + vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) + + cursor, err := vtab.Open() + if err == nil { + vtabPutHandle(ctx, mod, ppCur, cursor) + } + + // TODO: error message? + return errorCode(err, ERROR) +} + +func cursorCloseCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + err := cursor.Close() + // TODO: error message? + return errorCode(err, ERROR) } func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idxStr, argc, argv uint32) uint32 { - return uint32(ERROR) + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + db := ctx.Value(connKey{}).(*Conn) + args := callbackArgs(db, argc, argv) + err := cursor.Filter(int(idxNum), util.ReadString(mod, idxStr, _MAX_STRING), args...) + // TODO: error message? + return errorCode(err, ERROR) +} + +func cursorEOFCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + if cursor.EOF() { + return 1 + } + return 0 +} + +func cursorNextCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + err := cursor.Next() + // TODO: error message? + return errorCode(err, ERROR) } func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx, n uint32) uint32 { - return uint32(ERROR) + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + db := ctx.Value(connKey{}).(*Conn) + err := cursor.Column(&Context{db, pCtx}, int(n)) + // TODO: error message? + return errorCode(err, ERROR) } -func cursorRowidCallback(ctx context.Context, mod api.Module, pCur, pRowid uint32) uint32 { - return uint32(ERROR) -} +func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID uint32) uint32 { + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) -func cursorCallbackI(ctx context.Context, mod api.Module, _ uint32) uint32 { - return uint32(ERROR) + rowID, err := cursor.RowID() + if err == nil { + util.WriteUint64(mod, pRowID, uint64(rowID)) + } + + // TODO: error message? + return errorCode(err, ERROR) } func vtabGetHandle(ctx context.Context, mod api.Module, ptr uint32) any { diff --git a/vtab_test.go b/vtab_test.go index 649b440..848102e 100644 --- a/vtab_test.go +++ b/vtab_test.go @@ -1,6 +1,7 @@ package sqlite3_test import ( + "fmt" "log" "github.com/ncruces/go-sqlite3" @@ -19,13 +20,22 @@ func ExampleCreateModule() { log.Fatal(err) } - stmt, _, err := db.Prepare(`SELECT value FROM generate_series(5,100,5)`) + stmt, _, err := db.Prepare(`SELECT rowid, value FROM generate_series(2, 10, 3)`) if err != nil { log.Fatal(err) } defer stmt.Close() + for stmt.Step() { + fmt.Println(stmt.ColumnInt(0), stmt.ColumnInt(1)) + } + if err := stmt.Err(); err != nil { + log.Fatal(err) + } // Output: + // 2 2 + // 5 5 + // 8 8 } type seriesModule struct{} @@ -53,7 +63,7 @@ func (*seriesTable) BestIndex(idx *sqlite3.IndexInfo) error { idx.IdxStr = "default" argv := 1 for i, cst := range idx.Constraint { - if cst.Usable && cst.Op == sqlite3.Eq { + if cst.Op == sqlite3.Eq { idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ ArgvIndex: argv, Omit: true, @@ -64,4 +74,62 @@ func (*seriesTable) BestIndex(idx *sqlite3.IndexInfo) error { return nil } -func (*seriesTable) Open() (sqlite3.VTabCursor, error) { return nil, nil } +func (tab *seriesTable) Open() (sqlite3.VTabCursor, error) { + return &seriesCursor{tab, 0}, nil +} + +type seriesCursor struct { + *seriesTable + value int64 +} + +func (*seriesCursor) Close() error { + return nil +} + +func (cur *seriesCursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { + switch len(arg) { + case 0: + cur.seriesTable.start = 0 + cur.seriesTable.stop = 1000 + case 1: + cur.seriesTable.start = arg[0].Int64() + cur.seriesTable.stop = 1000 + case 2: + cur.seriesTable.start = arg[0].Int64() + cur.seriesTable.stop = arg[1].Int64() + case 3: + cur.seriesTable.start = arg[0].Int64() + cur.seriesTable.stop = arg[1].Int64() + cur.seriesTable.step = arg[2].Int64() + } + cur.value = cur.seriesTable.start + return nil +} + +func (cur *seriesCursor) Column(ctx *sqlite3.Context, col int) error { + switch col { + case 0: + ctx.ResultInt64(cur.value) + case 1: + ctx.ResultInt64(cur.start) + case 2: + ctx.ResultInt64(cur.stop) + case 3: + ctx.ResultInt64(cur.step) + } + return nil +} + +func (cur *seriesCursor) Next() error { + cur.value += cur.step + return nil +} + +func (cur *seriesCursor) EOF() bool { + return cur.value > cur.stop +} + +func (cur *seriesCursor) RowID() (int64, error) { + return int64(cur.value), nil +}