Statement virtual table.

This commit is contained in:
Nuno Cruces
2023-11-30 01:13:18 +00:00
parent 4160b9a4bb
commit b262f5cd01
4 changed files with 269 additions and 3 deletions

215
ext/statement/stmt.go Normal file
View File

@@ -0,0 +1,215 @@
// Package statement defines virtual tables and table-valued functions natively using SQL.
//
// https://github.com/0x09/sqlite-statement-vtab
package statement
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
// Register registers the statement virtual table.
func Register(db *sqlite3.Conn) {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
if arg == nil || len(arg[0]) < 3 {
return nil, fmt.Errorf("statement: no statement provided")
}
sql := arg[0]
if len := len(sql); sql[0] != '(' || sql[len-1] != ')' {
return nil, fmt.Errorf("statement: statement must be parenthesized")
} else {
sql = sql[1 : len-1]
}
table := &table{
db: db,
sql: sql,
}
err = table.declare()
if err != nil {
return nil, err
}
return table, nil
}
sqlite3.CreateModule(db, "statement", declare, declare)
}
type table struct {
db *sqlite3.Conn
sql string
inputs int
outputs int
}
func (t *table) declare() error {
stmt, tail, err := t.db.Prepare(t.sql)
if err != nil {
return err
}
defer stmt.Close()
if tail != "" {
return fmt.Errorf("statement: multiple statements")
}
// TODO: sqlite3_stmt_readonly
t.inputs = stmt.BindCount()
t.outputs = stmt.ColumnCount()
var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
for i := 0; i < t.outputs; i++ {
str.WriteString(sep)
name := stmt.ColumnName(i)
str.WriteString(sqlite3.QuoteIdentifier(name))
// TODO: sqlite3_column_decltype
sep = ","
}
for i := 1; i <= t.inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
if name == "" {
str.WriteByte('\'')
str.WriteString(strconv.Itoa(i))
str.WriteString("' HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
return t.db.DeclareVtab(str.String())
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.OrderByConsumed = false
idx.EstimatedCost = 1
idx.EstimatedRows = 1
var argvIndex = 1
var needIndex bool
var listIndex []int
for i, cst := range idx.Constraint {
// Skip if this is a constraint on one of our output columns.
if cst.Column < t.outputs {
continue
}
// A given query plan is only usable if all provided input columns
// are usable and have equal constraints only.
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
// The non-zero argvIdx values must be contiguous.
// If they're not, build a list and serialize it through IdxStr.
nextIndex := cst.Column - t.outputs + 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
}
if nextIndex != argvIndex {
needIndex = true
}
listIndex = append(listIndex, nextIndex)
argvIndex++
}
if needIndex {
buf, err := json.Marshal(listIndex)
if err != nil {
return err
}
idx.IdxStr = unsafe.String(&buf[0], len(buf))
}
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
stmt, _, err := t.db.Prepare(t.sql)
if err != nil {
return nil, err
}
return &cursor{table: t, stmt: stmt}, nil
}
func (t *table) Rename(new string) error {
return nil
}
type cursor struct {
table *table
stmt *sqlite3.Stmt
arg []sqlite3.Value
rowID int64
done bool
}
func (c *cursor) Close() error {
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.arg = arg
c.rowID = 0
if err := c.stmt.ClearBindings(); err != nil {
return err
}
if err := c.stmt.Reset(); err != nil {
return err
}
var list []int
if idxStr != "" {
err := json.Unmarshal([]byte(idxStr), &list)
if err != nil {
return err
}
}
for i, arg := range arg {
param := i + 1
if list != nil {
param = list[i]
}
err := c.stmt.BindValue(param, arg)
if err != nil {
return err
}
}
return c.Next()
}
func (c *cursor) Next() error {
if c.stmt.Step() {
c.rowID++
return nil
}
c.done = true
return c.stmt.Err()
}
func (c *cursor) EOF() bool {
return c.done
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
if col < c.table.outputs {
ctx.ResultValue(c.stmt.ColumnValue(col))
} else if col-c.table.outputs < len(c.arg) {
ctx.ResultValue(c.arg[col-c.table.outputs])
}
return nil
}

View File

@@ -0,0 +1,51 @@
package statement_test
import (
"fmt"
"log"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/statement"
"github.com/tetratelabs/wazero"
)
func Example() {
// This crashes the compiler.
sqlite3.RuntimeConfig = wazero.NewRuntimeConfigInterpreter()
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE split_date USING statement((
SELECT
strftime('%Y', :date) AS year,
strftime('%m', :date) AS month,
strftime('%d', :date) AS day
))`)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM split_date('2022-02-22')`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
fmt.Printf("Twosday was %d-%d-%d", stmt.ColumnInt(0), stmt.ColumnInt(1), stmt.ColumnInt(2))
}
if err := stmt.Reset(); err != nil {
log.Fatal(err)
}
// Output:
// Twosday was 2022-2-22
}