Pivot virtual table.

This commit is contained in:
Nuno Cruces
2023-12-06 15:39:26 +00:00
parent 8b45cac16b
commit 089a0c0670
14 changed files with 577 additions and 74 deletions

View File

@@ -119,7 +119,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return nil
}
func indexable(v reflect.Value) (_ reflect.Value, err error) {
func indexable(v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Slice {
return v, nil
}

View File

@@ -8,9 +8,9 @@ import (
)
func getSchema(header bool, columns int, row []string) string {
var sep = ""
var sep string
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
str.WriteString("CREATE TABLE x(")
if 0 <= columns && columns < len(row) {
row = row[:columns]
@@ -20,7 +20,7 @@ func getSchema(header bool, columns int, row []string) string {
if header && f != "" {
str.WriteString(sqlite3.QuoteIdentifier(f))
} else {
str.WriteByte('c')
str.WriteString("c")
str.WriteString(strconv.Itoa(i + 1))
}
str.WriteString(" TEXT")
@@ -28,7 +28,7 @@ func getSchema(header bool, columns int, row []string) string {
}
for i := len(row); i < columns; i++ {
str.WriteString(sep)
str.WriteByte('c')
str.WriteString("c")
str.WriteString(strconv.Itoa(i + 1))
str.WriteString(" TEXT")
sep = ","

267
ext/pivot/pivot.go Normal file
View File

@@ -0,0 +1,267 @@
// Package pivot implements a pivot virtual table.
//
// https://github.com/jakethaw/pivot_vtab
package pivot
import (
"errors"
"fmt"
"strings"
"github.com/ncruces/go-sqlite3"
)
// Register registers the pivot virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "pivot", declare, declare)
}
type table struct {
db *sqlite3.Conn
scan string
cell string
keys []string
cols []*sqlite3.Value
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
if len(arg) != 3 {
return nil, fmt.Errorf("pivot: wrong number of arguments")
}
table := &table{db: db}
defer func() {
if err != nil {
table.Close()
}
}()
var sep string
var create strings.Builder
create.WriteString("CREATE TABLE x(")
// Row key query.
table.scan = "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(table.scan)
if err != nil {
return nil, err
}
defer stmt.Close()
table.keys = make([]string, stmt.ColumnCount())
for i := range table.keys {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
table.keys[i] = name
create.WriteString(sep)
create.WriteString(name)
sep = ","
}
stmt.Close()
// Column definition query.
stmt, _, err = db.Prepare("SELECT * FROM\n" + arg[1])
if err != nil {
return nil, err
}
if stmt.ColumnCount() != 2 {
return nil, fmt.Errorf("pivot: column definition query expects 2 result columns")
}
for stmt.Step() {
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
table.cols = append(table.cols, stmt.ColumnValue(0).Dup())
create.WriteString(",")
create.WriteString(name)
}
stmt.Close()
// Pivot cell query.
table.cell = "SELECT * FROM\n" + arg[2]
stmt, _, err = db.Prepare(table.cell)
if err != nil {
return nil, err
}
if stmt.ColumnCount() != 1 {
return nil, fmt.Errorf("pivot: cell query expects 1 result columns")
}
if stmt.BindCount() != len(table.keys)+1 {
return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1)
}
create.WriteByte(')')
err = db.DeclareVtab(create.String())
if err != nil {
return nil, err
}
return table, nil
}
func (t *table) Close() error {
for i := range t.cols {
t.cols[i].Close()
}
return nil
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
var idxStr strings.Builder
idxStr.WriteString(t.scan)
argvIndex := 1
sep := " WHERE "
for i, cst := range idx.Constraint {
if !cst.Usable || !(0 <= cst.Column && cst.Column < len(t.keys)) {
continue
}
var op string
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
op = "="
case sqlite3.INDEX_CONSTRAINT_LT:
op = "<"
case sqlite3.INDEX_CONSTRAINT_GT:
op = ">"
case sqlite3.INDEX_CONSTRAINT_LE:
op = "<="
case sqlite3.INDEX_CONSTRAINT_GE:
op = ">="
case sqlite3.INDEX_CONSTRAINT_NE:
op = "<>"
case sqlite3.INDEX_CONSTRAINT_MATCH:
op = "MATCH"
case sqlite3.INDEX_CONSTRAINT_LIKE:
op = "LIKE"
case sqlite3.INDEX_CONSTRAINT_GLOB:
op = "GLOB"
case sqlite3.INDEX_CONSTRAINT_REGEXP:
op = "REGEXP"
case sqlite3.INDEX_CONSTRAINT_IS, sqlite3.INDEX_CONSTRAINT_ISNULL:
op = "IS"
case sqlite3.INDEX_CONSTRAINT_ISNOT, sqlite3.INDEX_CONSTRAINT_ISNOTNULL:
op = "IS NOT"
default:
continue
}
idxStr.WriteString(sep)
idxStr.WriteString(t.keys[cst.Column])
idxStr.WriteString(" ")
idxStr.WriteString(op)
idxStr.WriteString(" ?")
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
}
sep = " AND "
argvIndex++
}
sep = " ORDER BY "
idx.OrderByConsumed = true
for _, ord := range idx.OrderBy {
if !(0 <= ord.Column && ord.Column < len(t.keys)) {
idx.OrderByConsumed = false
continue
}
idxStr.WriteString(sep)
idxStr.WriteString(t.keys[ord.Column])
if ord.Desc {
idxStr.WriteString(" DESC")
}
sep = ","
}
idx.EstimatedCost = 1e9 / float64(argvIndex)
idx.IdxStr = idxStr.String()
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
return &cursor{table: t}, nil
}
func (t *table) Rename(new string) error {
return nil
}
type cursor struct {
table *table
scan *sqlite3.Stmt
cell *sqlite3.Stmt
rowID int64
}
func (c *cursor) Close() error {
return errors.Join(c.scan.Close(), c.cell.Close())
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
err := c.scan.Close()
if err != nil {
return err
}
c.scan, _, err = c.table.db.Prepare(idxStr)
if err != nil {
return err
}
for i, arg := range arg {
err := c.scan.BindValue(i+1, arg)
if err != nil {
return err
}
}
if c.cell == nil {
c.cell, _, err = c.table.db.Prepare(c.table.cell)
if err != nil {
return err
}
}
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
if c.scan.Step() {
count := c.scan.ColumnCount()
for i := 0; i < count; i++ {
err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
if err != nil {
return err
}
}
c.rowID++
}
return c.scan.Err()
}
func (c *cursor) EOF() bool {
return !c.scan.Busy()
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
count := c.scan.ColumnCount()
if col < count {
ctx.ResultValue(c.scan.ColumnValue(col))
return nil
}
err := c.cell.BindValue(count+1, *c.table.cols[col-count])
if err != nil {
return err
}
if c.cell.Step() {
ctx.ResultValue(c.cell.ColumnValue(0))
}
return c.cell.Reset()
}

219
ext/pivot/pivot_test.go Normal file
View File

@@ -0,0 +1,219 @@
package pivot_test
import (
"fmt"
"log"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/pivot"
)
// https://antonz.org/sqlite-pivot-table/
func Example() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE sales(product TEXT, year INT, income DECIMAL);
INSERT INTO sales(product, year, income) VALUES
('alpha', 2020, 100),
('alpha', 2021, 120),
('alpha', 2022, 130),
('alpha', 2023, 140),
('beta', 2020, 10),
('beta', 2021, 20),
('beta', 2022, 40),
('beta', 2023, 80),
('gamma', 2020, 80),
('gamma', 2021, 75),
('gamma', 2022, 78),
('gamma', 2023, 80);
`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE v_sales USING pivot(
-- rows
(SELECT DISTINCT product FROM sales),
-- columns
(SELECT DISTINCT year, year FROM sales),
-- cells
(SELECT sum(income) FROM sales WHERE product = ? AND year = ?)
)`)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM v_sales`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
cols := make([]string, stmt.ColumnCount())
for i := range cols {
cols[i] = stmt.ColumnName(i)
}
fmt.Println(pretty(cols))
for stmt.Step() {
for i := range cols {
cols[i] = stmt.ColumnText(i)
}
fmt.Println(pretty(cols))
}
if err := stmt.Reset(); err != nil {
log.Fatal(err)
}
// Output:
// product 2020 2021 2022 2023
// alpha 100 120 130 140
// beta 10 20 40 80
// gamma 80 75 78 80
}
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE r AS
SELECT 1 id UNION SELECT 2 UNION SELECT 3;
CREATE TABLE c(
id INTEGER PRIMARY KEY,
name TEXT
);
INSERT INTO c (name) VALUES
('a'),('b'),('c'),('d');
CREATE TABLE x(
r_id INT,
c_id INT,
val TEXT
);
INSERT INTO x (r_id, c_id, val)
SELECT r.id, c.id, c.name || r.id
FROM c, r;
`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE v_x USING pivot(
-- rows
(SELECT id r_id FROM r),
-- columns
(SELECT id c_id, name FROM c),
-- cells
(SELECT val FROM x WHERE r_id = ?1 AND c_id = ?2)
)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM v_x WHERE rowid <> 0 AND r_id <> 1 ORDER BY rowid, r_id DESC LIMIT 1`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %d, want 3", got)
}
}
}
func TestRegister_errors(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot(SELECT 1, SELECT 2, SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), SELECT 2, SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 2), SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), (SELECT 3, 4))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), (SELECT 3))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
}
func pretty(cols []string) string {
var buf strings.Builder
for i, s := range cols {
if i != 0 {
buf.WriteByte(' ')
}
for buf.Len()%8 != 0 {
buf.WriteByte(' ')
}
buf.WriteString(s)
}
return buf.String()
}

View File

@@ -15,55 +15,6 @@ import (
// Register registers the statement virtual table.
func Register(db *sqlite3.Conn) {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, fmt.Errorf("statement: wrong number of arguments")
}
sql := "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(sql)
if err != nil {
return nil, err
}
var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
outputs := stmt.ColumnCount()
for i := 0; i < outputs; i++ {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
str.WriteByte(' ')
str.WriteString(stmt.ColumnDeclType(i))
sep = ","
}
inputs := stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
str.WriteString("] HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
err = db.DeclareVtab(str.String())
if err != nil {
stmt.Close()
return nil, err
}
return &table{sql: sql, stmt: stmt}, nil
}
sqlite3.CreateModule(db, "statement", declare, declare)
}
@@ -73,6 +24,55 @@ type table struct {
inuse bool
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, fmt.Errorf("statement: wrong number of arguments")
}
sql := "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(sql)
if err != nil {
return nil, err
}
var sep string
var str strings.Builder
str.WriteString("CREATE TABLE x(")
outputs := stmt.ColumnCount()
for i := 0; i < outputs; i++ {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
str.WriteString(" ")
str.WriteString(stmt.ColumnDeclType(i))
sep = ","
}
inputs := stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
str.WriteString("] HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
err = db.DeclareVtab(str.String())
if err != nil {
stmt.Close()
return nil, err
}
return &table{sql: sql, stmt: stmt}, nil
}
func (t *table) Close() error {
return t.stmt.Close()
}
@@ -120,11 +120,12 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
return nil
}
func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
func (t *table) Open() (sqlite3.VTabCursor, error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
var err error
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
@@ -186,7 +187,6 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
func (c *cursor) Next() error {
if c.stmt.Step() {
c.rowID++
return nil
}
return c.stmt.Err()
}

View File

@@ -95,7 +95,6 @@ func TestRegister(t *testing.T) {
t.Errorf("hypot(%d, %d) = %d", x, y, hypot)
}
}
}
func TestRegister_errors(t *testing.T) {