diff --git a/README.md b/README.md index dc874d2..1da5d16 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,10 @@ and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings. reads [comma-separated values](https://sqlite.org/csv.html). - [`github.com/ncruces/go-sqlite3/ext/lines`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/lines) reads files [line-by-line](https://github.com/asg017/sqlite-lines). +- [`github.com/ncruces/go-sqlite3/ext/pivot`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/pivot) + creates [pivot tables](https://github.com/jakethaw/pivot_vtab). +- [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement) + creates [table-valued functions with SQL](https://github.com/0x09/sqlite-statement-vtab). - [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats) provides [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html). - [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode) diff --git a/conn.go b/conn.go index d637033..2e339ac 100644 --- a/conn.go +++ b/conn.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/url" - "runtime" "strings" "github.com/ncruces/go-sqlite3/internal/util" @@ -56,8 +55,6 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { defer func() { if conn == nil { sqlite.close() - } else { - runtime.SetFinalizer(conn, util.Finalizer[Conn](3)) } }() @@ -92,7 +89,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { for _, p := range query["_pragma"] { pragmas.WriteString(`PRAGMA `) pragmas.WriteString(p) - pragmas.WriteByte(';') + pragmas.WriteString(`;`) } } @@ -140,7 +137,6 @@ func (c *Conn) Close() error { } c.handle = 0 - runtime.SetFinalizer(c, nil) return c.close() } @@ -194,7 +190,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str if stmt.handle == 0 { return nil, "", nil } - return + return stmt, tail, nil } // GetAutocommit tests the connection for auto-commit mode. diff --git a/embed/exports.txt b/embed/exports.txt index ddd284d..0fefb37 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -83,6 +83,8 @@ sqlite3_user_data sqlite3_value_blob sqlite3_value_bytes sqlite3_value_double +sqlite3_value_dup +sqlite3_value_free sqlite3_value_int64 sqlite3_value_nochange sqlite3_value_pointer_go diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index c900bda..b65fcd1 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/error.go b/error.go index ca86b66..838f1aa 100644 --- a/error.go +++ b/error.go @@ -44,8 +44,7 @@ func (e *Error) Error() string { } if e.msg != "" { - b.WriteByte(':') - b.WriteByte(' ') + b.WriteString(": ") b.WriteString(e.msg) } diff --git a/ext/array/array.go b/ext/array/array.go index 928aa6b..41eced8 100644 --- a/ext/array/array.go +++ b/ext/array/array.go @@ -119,7 +119,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { return nil } -func indexable(v reflect.Value) (_ reflect.Value, err error) { +func indexable(v reflect.Value) (reflect.Value, error) { if v.Kind() == reflect.Slice { return v, nil } diff --git a/ext/csv/schema.go b/ext/csv/schema.go index 2ca807c..e3243e1 100644 --- a/ext/csv/schema.go +++ b/ext/csv/schema.go @@ -8,9 +8,9 @@ import ( ) func getSchema(header bool, columns int, row []string) string { - var sep = "" + var sep string var str strings.Builder - str.WriteString(`CREATE TABLE x(`) + str.WriteString("CREATE TABLE x(") if 0 <= columns && columns < len(row) { row = row[:columns] @@ -20,7 +20,7 @@ func getSchema(header bool, columns int, row []string) string { if header && f != "" { str.WriteString(sqlite3.QuoteIdentifier(f)) } else { - str.WriteByte('c') + str.WriteString("c") str.WriteString(strconv.Itoa(i + 1)) } str.WriteString(" TEXT") @@ -28,7 +28,7 @@ func getSchema(header bool, columns int, row []string) string { } for i := len(row); i < columns; i++ { str.WriteString(sep) - str.WriteByte('c') + str.WriteString("c") str.WriteString(strconv.Itoa(i + 1)) str.WriteString(" TEXT") sep = "," diff --git a/ext/pivot/pivot.go b/ext/pivot/pivot.go new file mode 100644 index 0000000..f9b20b6 --- /dev/null +++ b/ext/pivot/pivot.go @@ -0,0 +1,267 @@ +// Package pivot implements a pivot virtual table. +// +// https://github.com/jakethaw/pivot_vtab +package pivot + +import ( + "errors" + "fmt" + "strings" + + "github.com/ncruces/go-sqlite3" +) + +// Register registers the pivot virtual table. +func Register(db *sqlite3.Conn) { + sqlite3.CreateModule(db, "pivot", declare, declare) +} + +type table struct { + db *sqlite3.Conn + scan string + cell string + keys []string + cols []*sqlite3.Value +} + +func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { + if len(arg) != 3 { + return nil, fmt.Errorf("pivot: wrong number of arguments") + } + + table := &table{db: db} + defer func() { + if err != nil { + table.Close() + } + }() + + var sep string + var create strings.Builder + create.WriteString("CREATE TABLE x(") + + // Row key query. + table.scan = "SELECT * FROM\n" + arg[0] + stmt, _, err := db.Prepare(table.scan) + if err != nil { + return nil, err + } + defer stmt.Close() + + table.keys = make([]string, stmt.ColumnCount()) + for i := range table.keys { + name := sqlite3.QuoteIdentifier(stmt.ColumnName(i)) + table.keys[i] = name + create.WriteString(sep) + create.WriteString(name) + sep = "," + } + stmt.Close() + + // Column definition query. + stmt, _, err = db.Prepare("SELECT * FROM\n" + arg[1]) + if err != nil { + return nil, err + } + + if stmt.ColumnCount() != 2 { + return nil, fmt.Errorf("pivot: column definition query expects 2 result columns") + } + for stmt.Step() { + name := sqlite3.QuoteIdentifier(stmt.ColumnText(1)) + table.cols = append(table.cols, stmt.ColumnValue(0).Dup()) + create.WriteString(",") + create.WriteString(name) + } + stmt.Close() + + // Pivot cell query. + table.cell = "SELECT * FROM\n" + arg[2] + stmt, _, err = db.Prepare(table.cell) + if err != nil { + return nil, err + } + + if stmt.ColumnCount() != 1 { + return nil, fmt.Errorf("pivot: cell query expects 1 result columns") + } + if stmt.BindCount() != len(table.keys)+1 { + return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1) + } + + create.WriteByte(')') + err = db.DeclareVtab(create.String()) + if err != nil { + return nil, err + } + return table, nil +} + +func (t *table) Close() error { + for i := range t.cols { + t.cols[i].Close() + } + return nil +} + +func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { + var idxStr strings.Builder + idxStr.WriteString(t.scan) + + argvIndex := 1 + sep := " WHERE " + for i, cst := range idx.Constraint { + if !cst.Usable || !(0 <= cst.Column && cst.Column < len(t.keys)) { + continue + } + + var op string + switch cst.Op { + case sqlite3.INDEX_CONSTRAINT_EQ: + op = "=" + case sqlite3.INDEX_CONSTRAINT_LT: + op = "<" + case sqlite3.INDEX_CONSTRAINT_GT: + op = ">" + case sqlite3.INDEX_CONSTRAINT_LE: + op = "<=" + case sqlite3.INDEX_CONSTRAINT_GE: + op = ">=" + case sqlite3.INDEX_CONSTRAINT_NE: + op = "<>" + case sqlite3.INDEX_CONSTRAINT_MATCH: + op = "MATCH" + case sqlite3.INDEX_CONSTRAINT_LIKE: + op = "LIKE" + case sqlite3.INDEX_CONSTRAINT_GLOB: + op = "GLOB" + case sqlite3.INDEX_CONSTRAINT_REGEXP: + op = "REGEXP" + case sqlite3.INDEX_CONSTRAINT_IS, sqlite3.INDEX_CONSTRAINT_ISNULL: + op = "IS" + case sqlite3.INDEX_CONSTRAINT_ISNOT, sqlite3.INDEX_CONSTRAINT_ISNOTNULL: + op = "IS NOT" + default: + continue + } + + idxStr.WriteString(sep) + idxStr.WriteString(t.keys[cst.Column]) + idxStr.WriteString(" ") + idxStr.WriteString(op) + idxStr.WriteString(" ?") + + idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ + ArgvIndex: argvIndex, + Omit: true, + } + sep = " AND " + argvIndex++ + } + + sep = " ORDER BY " + idx.OrderByConsumed = true + for _, ord := range idx.OrderBy { + if !(0 <= ord.Column && ord.Column < len(t.keys)) { + idx.OrderByConsumed = false + continue + } + idxStr.WriteString(sep) + idxStr.WriteString(t.keys[ord.Column]) + if ord.Desc { + idxStr.WriteString(" DESC") + } + sep = "," + } + + idx.EstimatedCost = 1e9 / float64(argvIndex) + idx.IdxStr = idxStr.String() + return nil +} + +func (t *table) Open() (sqlite3.VTabCursor, error) { + return &cursor{table: t}, nil +} + +func (t *table) Rename(new string) error { + return nil +} + +type cursor struct { + table *table + scan *sqlite3.Stmt + cell *sqlite3.Stmt + rowID int64 +} + +func (c *cursor) Close() error { + return errors.Join(c.scan.Close(), c.cell.Close()) +} + +func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { + err := c.scan.Close() + if err != nil { + return err + } + + c.scan, _, err = c.table.db.Prepare(idxStr) + if err != nil { + return err + } + for i, arg := range arg { + err := c.scan.BindValue(i+1, arg) + if err != nil { + return err + } + } + + if c.cell == nil { + c.cell, _, err = c.table.db.Prepare(c.table.cell) + if err != nil { + return err + } + } + + c.rowID = 0 + return c.Next() +} + +func (c *cursor) Next() error { + if c.scan.Step() { + count := c.scan.ColumnCount() + for i := 0; i < count; i++ { + err := c.cell.BindValue(i+1, c.scan.ColumnValue(i)) + if err != nil { + return err + } + } + c.rowID++ + } + return c.scan.Err() +} + +func (c *cursor) EOF() bool { + return !c.scan.Busy() +} + +func (c *cursor) RowID() (int64, error) { + return c.rowID, nil +} + +func (c *cursor) Column(ctx *sqlite3.Context, col int) error { + count := c.scan.ColumnCount() + if col < count { + ctx.ResultValue(c.scan.ColumnValue(col)) + return nil + } + + err := c.cell.BindValue(count+1, *c.table.cols[col-count]) + if err != nil { + return err + } + + if c.cell.Step() { + ctx.ResultValue(c.cell.ColumnValue(0)) + } + return c.cell.Reset() +} diff --git a/ext/pivot/pivot_test.go b/ext/pivot/pivot_test.go new file mode 100644 index 0000000..97970eb --- /dev/null +++ b/ext/pivot/pivot_test.go @@ -0,0 +1,219 @@ +package pivot_test + +import ( + "fmt" + "log" + "strings" + "testing" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/pivot" +) + +// https://antonz.org/sqlite-pivot-table/ +func Example() { + db, err := sqlite3.Open(":memory:") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + pivot.Register(db) + + err = db.Exec(` + CREATE TABLE sales(product TEXT, year INT, income DECIMAL); + INSERT INTO sales(product, year, income) VALUES + ('alpha', 2020, 100), + ('alpha', 2021, 120), + ('alpha', 2022, 130), + ('alpha', 2023, 140), + ('beta', 2020, 10), + ('beta', 2021, 20), + ('beta', 2022, 40), + ('beta', 2023, 80), + ('gamma', 2020, 80), + ('gamma', 2021, 75), + ('gamma', 2022, 78), + ('gamma', 2023, 80); + `) + if err != nil { + log.Fatal(err) + } + + err = db.Exec(` + CREATE VIRTUAL TABLE v_sales USING pivot( + -- rows + (SELECT DISTINCT product FROM sales), + -- columns + (SELECT DISTINCT year, year FROM sales), + -- cells + (SELECT sum(income) FROM sales WHERE product = ? AND year = ?) + )`) + if err != nil { + log.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT * FROM v_sales`) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + cols := make([]string, stmt.ColumnCount()) + for i := range cols { + cols[i] = stmt.ColumnName(i) + } + fmt.Println(pretty(cols)) + for stmt.Step() { + for i := range cols { + cols[i] = stmt.ColumnText(i) + } + fmt.Println(pretty(cols)) + } + if err := stmt.Reset(); err != nil { + log.Fatal(err) + } + + // Output: + // product 2020 2021 2022 2023 + // alpha 100 120 130 140 + // beta 10 20 40 80 + // gamma 80 75 78 80 +} + +func TestRegister(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + pivot.Register(db) + + err = db.Exec(` + CREATE TABLE r AS + SELECT 1 id UNION SELECT 2 UNION SELECT 3; + + CREATE TABLE c( + id INTEGER PRIMARY KEY, + name TEXT + ); + INSERT INTO c (name) VALUES + ('a'),('b'),('c'),('d'); + + CREATE TABLE x( + r_id INT, + c_id INT, + val TEXT + ); + INSERT INTO x (r_id, c_id, val) + SELECT r.id, c.id, c.name || r.id + FROM c, r; + `) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(` + CREATE VIRTUAL TABLE v_x USING pivot( + -- rows + (SELECT id r_id FROM r), + -- columns + (SELECT id c_id, name FROM c), + -- cells + (SELECT val FROM x WHERE r_id = ?1 AND c_id = ?2) + )`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT * FROM v_x WHERE rowid <> 0 AND r_id <> 1 ORDER BY rowid, r_id DESC LIMIT 1`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + if got := stmt.ColumnInt(0); got != 3 { + t.Errorf("got %d, want 3", got) + } + } +} + +func TestRegister_errors(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + pivot.Register(db) + + err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot(SELECT 1, SELECT 2, SELECT 3)`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), SELECT 2, SELECT 3)`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 2), SELECT 3)`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), SELECT 3)`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), (SELECT 3, 4))`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), (SELECT 3))`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } +} + +func pretty(cols []string) string { + var buf strings.Builder + for i, s := range cols { + if i != 0 { + buf.WriteByte(' ') + } + for buf.Len()%8 != 0 { + buf.WriteByte(' ') + } + buf.WriteString(s) + } + return buf.String() +} diff --git a/ext/statement/stmt.go b/ext/statement/stmt.go index 951280a..77289c0 100644 --- a/ext/statement/stmt.go +++ b/ext/statement/stmt.go @@ -15,55 +15,6 @@ import ( // Register registers the statement virtual table. func Register(db *sqlite3.Conn) { - declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) { - if len(arg) != 1 { - return nil, fmt.Errorf("statement: wrong number of arguments") - } - - sql := "SELECT * FROM\n" + arg[0] - - stmt, _, err := db.Prepare(sql) - if err != nil { - return nil, err - } - - var sep = "" - var str strings.Builder - str.WriteString(`CREATE TABLE x(`) - outputs := stmt.ColumnCount() - for i := 0; i < outputs; i++ { - name := sqlite3.QuoteIdentifier(stmt.ColumnName(i)) - str.WriteString(sep) - str.WriteString(name) - str.WriteByte(' ') - str.WriteString(stmt.ColumnDeclType(i)) - sep = "," - } - inputs := stmt.BindCount() - for i := 1; i <= inputs; i++ { - str.WriteString(sep) - name := stmt.BindName(i) - if name == "" { - str.WriteString("[") - str.WriteString(strconv.Itoa(i)) - str.WriteString("] HIDDEN") - } else { - str.WriteString(sqlite3.QuoteIdentifier(name[1:])) - str.WriteString(" HIDDEN") - } - sep = "," - } - str.WriteByte(')') - - err = db.DeclareVtab(str.String()) - if err != nil { - stmt.Close() - return nil, err - } - - return &table{sql: sql, stmt: stmt}, nil - } - sqlite3.CreateModule(db, "statement", declare, declare) } @@ -73,6 +24,55 @@ type table struct { inuse bool } +func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) { + if len(arg) != 1 { + return nil, fmt.Errorf("statement: wrong number of arguments") + } + + sql := "SELECT * FROM\n" + arg[0] + + stmt, _, err := db.Prepare(sql) + if err != nil { + return nil, err + } + + var sep string + var str strings.Builder + str.WriteString("CREATE TABLE x(") + outputs := stmt.ColumnCount() + for i := 0; i < outputs; i++ { + name := sqlite3.QuoteIdentifier(stmt.ColumnName(i)) + str.WriteString(sep) + str.WriteString(name) + str.WriteString(" ") + str.WriteString(stmt.ColumnDeclType(i)) + sep = "," + } + inputs := stmt.BindCount() + for i := 1; i <= inputs; i++ { + str.WriteString(sep) + name := stmt.BindName(i) + if name == "" { + str.WriteString("[") + str.WriteString(strconv.Itoa(i)) + str.WriteString("] HIDDEN") + } else { + str.WriteString(sqlite3.QuoteIdentifier(name[1:])) + str.WriteString(" HIDDEN") + } + sep = "," + } + str.WriteByte(')') + + err = db.DeclareVtab(str.String()) + if err != nil { + stmt.Close() + return nil, err + } + + return &table{sql: sql, stmt: stmt}, nil +} + func (t *table) Close() error { return t.stmt.Close() } @@ -120,11 +120,12 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { return nil } -func (t *table) Open() (_ sqlite3.VTabCursor, err error) { +func (t *table) Open() (sqlite3.VTabCursor, error) { stmt := t.stmt if !t.inuse { t.inuse = true } else { + var err error stmt, _, err = t.stmt.Conn().Prepare(t.sql) if err != nil { return nil, err @@ -186,7 +187,6 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { func (c *cursor) Next() error { if c.stmt.Step() { c.rowID++ - return nil } return c.stmt.Err() } diff --git a/ext/statement/stmt_test.go b/ext/statement/stmt_test.go index 208edcf..9e9d6a9 100644 --- a/ext/statement/stmt_test.go +++ b/ext/statement/stmt_test.go @@ -95,7 +95,6 @@ func TestRegister(t *testing.T) { t.Errorf("hypot(%d, %d) = %d", x, y, hypot) } } - } func TestRegister_errors(t *testing.T) { diff --git a/internal/util/error.go b/internal/util/error.go index 20f80ac..1f5555f 100644 --- a/internal/util/error.go +++ b/internal/util/error.go @@ -1,7 +1,6 @@ package util import ( - "fmt" "runtime" "strconv" ) @@ -34,14 +33,6 @@ func AssertErr() ErrorString { return ErrorString(msg) } -func Finalizer[T any](skip int) func(*T) { - msg := fmt.Sprintf("sqlite3: %T not closed", new(T)) - if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 { - msg += " (" + file + ":" + strconv.Itoa(line) + ")" - } - return func(*T) { panic(ErrorString(msg)) } -} - func ErrorCodeString(rc uint32) string { switch rc { case ABORT_ROLLBACK: diff --git a/stmt.go b/stmt.go index 85ac447..e9b80ae 100644 --- a/stmt.go +++ b/stmt.go @@ -81,6 +81,7 @@ func (s *Stmt) Step() bool { r := s.c.call("sqlite3_step", uint64(s.handle)) switch r { case _ROW: + s.err = nil return true case _DONE: s.err = nil diff --git a/value.go b/value.go index 0426aad..e180eb4 100644 --- a/value.go +++ b/value.go @@ -16,6 +16,7 @@ type Value struct { *sqlite handle uint32 unprot bool + copied bool } func (v Value) protected() uint64 { @@ -25,6 +26,30 @@ func (v Value) protected() uint64 { return uint64(v.handle) } +// Dup makes a copy of the SQL value and returns a pointer to that copy. +// +// https://sqlite.org/c3ref/value_dup.html +func (v Value) Dup() *Value { + r := v.call("sqlite3_value_dup", uint64(v.handle)) + return &Value{ + copied: true, + sqlite: v.sqlite, + handle: uint32(r), + } +} + +// Close frees an SQL value previously obtained by [Value.Dup]. +// +// https://sqlite.org/c3ref/value_dup.html +func (dup *Value) Close() error { + if !dup.copied { + panic(util.ValueErr) + } + dup.call("sqlite3_value_free", uint64(dup.handle)) + dup.handle = 0 + return nil +} + // Type returns the initial [Datatype] of the value. // // https://sqlite.org/c3ref/value_blob.html