From b262f5cd01146e1b6368046ae1889f2136ace708 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 30 Nov 2023 01:13:18 +0000 Subject: [PATCH] Statement virtual table. --- ext/statement/stmt.go | 215 +++++++++++++++++++++++++++++++++++++ ext/statement/stmt_test.go | 51 +++++++++ go.mod | 2 +- go.sum | 4 +- 4 files changed, 269 insertions(+), 3 deletions(-) create mode 100644 ext/statement/stmt.go create mode 100644 ext/statement/stmt_test.go diff --git a/ext/statement/stmt.go b/ext/statement/stmt.go new file mode 100644 index 0000000..0dacc79 --- /dev/null +++ b/ext/statement/stmt.go @@ -0,0 +1,215 @@ +// Package statement defines virtual tables and table-valued functions natively using SQL. +// +// https://github.com/0x09/sqlite-statement-vtab +package statement + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + "unsafe" + + "github.com/ncruces/go-sqlite3" +) + +// Register registers the statement virtual table. +func Register(db *sqlite3.Conn) { + declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { + if arg == nil || len(arg[0]) < 3 { + return nil, fmt.Errorf("statement: no statement provided") + } + sql := arg[0] + if len := len(sql); sql[0] != '(' || sql[len-1] != ')' { + return nil, fmt.Errorf("statement: statement must be parenthesized") + } else { + sql = sql[1 : len-1] + } + + table := &table{ + db: db, + sql: sql, + } + err = table.declare() + if err != nil { + return nil, err + } + return table, nil + } + + sqlite3.CreateModule(db, "statement", declare, declare) +} + +type table struct { + db *sqlite3.Conn + sql string + inputs int + outputs int +} + +func (t *table) declare() error { + stmt, tail, err := t.db.Prepare(t.sql) + if err != nil { + return err + } + defer stmt.Close() + if tail != "" { + return fmt.Errorf("statement: multiple statements") + } + // TODO: sqlite3_stmt_readonly + + t.inputs = stmt.BindCount() + t.outputs = stmt.ColumnCount() + + var sep = "" + var str strings.Builder + str.WriteString(`CREATE TABLE x(`) + for i := 0; i < t.outputs; i++ { + str.WriteString(sep) + name := stmt.ColumnName(i) + str.WriteString(sqlite3.QuoteIdentifier(name)) + // TODO: sqlite3_column_decltype + sep = "," + } + for i := 1; i <= t.inputs; i++ { + str.WriteString(sep) + name := stmt.BindName(i) + if name == "" { + str.WriteByte('\'') + str.WriteString(strconv.Itoa(i)) + str.WriteString("' HIDDEN") + } else { + str.WriteString(sqlite3.QuoteIdentifier(name)) + str.WriteString(" HIDDEN") + } + sep = "," + } + + str.WriteByte(')') + return t.db.DeclareVtab(str.String()) +} + +func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { + idx.OrderByConsumed = false + idx.EstimatedCost = 1 + idx.EstimatedRows = 1 + + var argvIndex = 1 + var needIndex bool + var listIndex []int + for i, cst := range idx.Constraint { + // Skip if this is a constraint on one of our output columns. + if cst.Column < t.outputs { + continue + } + + // A given query plan is only usable if all provided input columns + // are usable and have equal constraints only. + if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ { + return sqlite3.CONSTRAINT + } + + // The non-zero argvIdx values must be contiguous. + // If they're not, build a list and serialize it through IdxStr. + nextIndex := cst.Column - t.outputs + 1 + idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ + ArgvIndex: argvIndex, + Omit: true, + } + if nextIndex != argvIndex { + needIndex = true + } + listIndex = append(listIndex, nextIndex) + argvIndex++ + } + + if needIndex { + buf, err := json.Marshal(listIndex) + if err != nil { + return err + } + idx.IdxStr = unsafe.String(&buf[0], len(buf)) + } + return nil +} + +func (t *table) Open() (sqlite3.VTabCursor, error) { + stmt, _, err := t.db.Prepare(t.sql) + if err != nil { + return nil, err + } + return &cursor{table: t, stmt: stmt}, nil +} + +func (t *table) Rename(new string) error { + return nil +} + +type cursor struct { + table *table + stmt *sqlite3.Stmt + arg []sqlite3.Value + rowID int64 + done bool +} + +func (c *cursor) Close() error { + return c.stmt.Close() +} + +func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { + c.arg = arg + c.rowID = 0 + if err := c.stmt.ClearBindings(); err != nil { + return err + } + if err := c.stmt.Reset(); err != nil { + return err + } + + var list []int + if idxStr != "" { + err := json.Unmarshal([]byte(idxStr), &list) + if err != nil { + return err + } + } + + for i, arg := range arg { + param := i + 1 + if list != nil { + param = list[i] + } + err := c.stmt.BindValue(param, arg) + if err != nil { + return err + } + } + return c.Next() +} + +func (c *cursor) Next() error { + if c.stmt.Step() { + c.rowID++ + return nil + } + c.done = true + return c.stmt.Err() +} + +func (c *cursor) EOF() bool { + return c.done +} + +func (c *cursor) RowID() (int64, error) { + return c.rowID, nil +} + +func (c *cursor) Column(ctx *sqlite3.Context, col int) error { + if col < c.table.outputs { + ctx.ResultValue(c.stmt.ColumnValue(col)) + } else if col-c.table.outputs < len(c.arg) { + ctx.ResultValue(c.arg[col-c.table.outputs]) + } + return nil +} diff --git a/ext/statement/stmt_test.go b/ext/statement/stmt_test.go new file mode 100644 index 0000000..8b4d696 --- /dev/null +++ b/ext/statement/stmt_test.go @@ -0,0 +1,51 @@ +package statement_test + +import ( + "fmt" + "log" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/statement" + "github.com/tetratelabs/wazero" +) + +func Example() { + // This crashes the compiler. + sqlite3.RuntimeConfig = wazero.NewRuntimeConfigInterpreter() + + db, err := sqlite3.Open(":memory:") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + statement.Register(db) + + err = db.Exec(` + CREATE VIRTUAL TABLE split_date USING statement(( + SELECT + strftime('%Y', :date) AS year, + strftime('%m', :date) AS month, + strftime('%d', :date) AS day + ))`) + if err != nil { + log.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT * FROM split_date('2022-02-22')`) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + fmt.Printf("Twosday was %d-%d-%d", stmt.ColumnInt(0), stmt.ColumnInt(1), stmt.ColumnInt(2)) + } + if err := stmt.Reset(); err != nil { + log.Fatal(err) + } + + // Output: + // Twosday was 2022-2-22 +} diff --git a/go.mod b/go.mod index b7d76d1..f78358d 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 require ( github.com/ncruces/julianday v1.0.0 github.com/psanford/httpreadat v0.1.0 - github.com/tetratelabs/wazero v1.5.0 + github.com/tetratelabs/wazero v1.5.1-0.20231130010520-d01ebff34db8 golang.org/x/sync v0.5.0 golang.org/x/sys v0.15.0 golang.org/x/text v0.14.0 diff --git a/go.sum b/go.sum index 5a08476..3aecc89 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE= github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE= -github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0= -github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A= +github.com/tetratelabs/wazero v1.5.1-0.20231130010520-d01ebff34db8 h1:0LUxnUU9dmbCq7fwC422Bx5ZjdzJJUfw66P1A0/DZcc= +github.com/tetratelabs/wazero v1.5.1-0.20231130010520-d01ebff34db8/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=