mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Transitive closure virtual table.
This commit is contained in:
263
ext/closure/closure.go
Normal file
263
ext/closure/closure.go
Normal file
@@ -0,0 +1,263 @@
|
||||
// Package closure provides a transitive closure virtual table.
|
||||
//
|
||||
// The "transitive_closure" virtual table finds the transitive closure of
|
||||
// a parent/child relationship in a real table.
|
||||
//
|
||||
// https://sqlite.org/src/doc/tip/ext/misc/closure.c
|
||||
package closure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/util/vtabutil"
|
||||
)
|
||||
|
||||
const (
|
||||
_COL_ID = 0
|
||||
_COL_DEPTH = 1
|
||||
_COL_ROOT = 2
|
||||
_COL_TABLENAME = 3
|
||||
_COL_IDCOLUMN = 4
|
||||
_COL_PARENTCOLUMN = 5
|
||||
)
|
||||
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return sqlite3.CreateModule(db, "transitive_closure", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*closure, error) {
|
||||
var (
|
||||
table string
|
||||
column string
|
||||
parent string
|
||||
|
||||
done = util.Set[string]{}
|
||||
)
|
||||
|
||||
for _, arg := range arg {
|
||||
key, val := vtabutil.NamedArg(arg)
|
||||
if done.Contains(key) {
|
||||
return nil, fmt.Errorf("transitive_closure: more than one %q parameter", key)
|
||||
}
|
||||
switch key {
|
||||
case "tablename":
|
||||
table = vtabutil.Unquote(val)
|
||||
case "idcolumn":
|
||||
column = vtabutil.Unquote(val)
|
||||
case "parentcolumn":
|
||||
parent = vtabutil.Unquote(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("transitive_closure: unknown %q parameter", key)
|
||||
}
|
||||
done.Add(key)
|
||||
}
|
||||
|
||||
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &closure{
|
||||
db: db,
|
||||
table: table,
|
||||
column: column,
|
||||
parent: parent,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
type closure struct {
|
||||
db *sqlite3.Conn
|
||||
table string
|
||||
column string
|
||||
parent string
|
||||
}
|
||||
|
||||
func (c *closure) Destroy() error { return nil }
|
||||
|
||||
func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
posi := 1
|
||||
plan := 0
|
||||
cost := 10000000.0
|
||||
|
||||
for i, cst := range idx.Constraint {
|
||||
if !cst.Usable {
|
||||
continue
|
||||
}
|
||||
if plan&1 == 0 && cst.Column == _COL_ROOT {
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= 1
|
||||
cost /= 100
|
||||
idx.ConstraintUsage[i].ArgvIndex = 1
|
||||
idx.ConstraintUsage[i].Omit = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf0 == 0 && cst.Column == _COL_DEPTH {
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 4
|
||||
cost /= 5
|
||||
posi += 1
|
||||
idx.ConstraintUsage[i].ArgvIndex = posi
|
||||
if cst.Op == sqlite3.INDEX_CONSTRAINT_LT {
|
||||
plan |= 2
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf00 == 0 && cst.Column == _COL_TABLENAME {
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 8
|
||||
cost /= 5
|
||||
posi += 1
|
||||
idx.ConstraintUsage[i].ArgvIndex = posi
|
||||
idx.ConstraintUsage[i].Omit = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN {
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 12
|
||||
posi += 1
|
||||
idx.ConstraintUsage[i].ArgvIndex = posi
|
||||
idx.ConstraintUsage[i].Omit = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN {
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 16
|
||||
posi += 1
|
||||
idx.ConstraintUsage[i].ArgvIndex = posi
|
||||
idx.ConstraintUsage[i].Omit = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if c.table == "" && plan&0xf00 == 0 ||
|
||||
c.column == "" && plan&0xf000 == 0 ||
|
||||
c.parent == "" && plan&0xf0000 == 0 {
|
||||
plan = 0
|
||||
}
|
||||
if plan&1 == 0 {
|
||||
plan = 0
|
||||
cost *= 1e30
|
||||
for i := range idx.Constraint {
|
||||
idx.ConstraintUsage[i].ArgvIndex = 0
|
||||
}
|
||||
}
|
||||
|
||||
idx.EstimatedCost = cost
|
||||
idx.IdxNum = plan
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *closure) Open() (sqlite3.VTabCursor, error) {
|
||||
return &cursor{closure: c}, nil
|
||||
}
|
||||
|
||||
type cursor struct {
|
||||
*closure
|
||||
nodes []node
|
||||
}
|
||||
|
||||
type node struct {
|
||||
id int64
|
||||
depth int
|
||||
}
|
||||
|
||||
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
if idxNum&1 == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
root := arg[0].Int64()
|
||||
maxDepth := math.MaxInt
|
||||
if idxNum&0xf0 != 0 {
|
||||
maxDepth = arg[(idxNum>>4)&0xf].Int()
|
||||
if idxNum&2 != 0 {
|
||||
maxDepth -= 1
|
||||
}
|
||||
}
|
||||
table := c.table
|
||||
if idxNum&0xf00 != 0 {
|
||||
table = arg[(idxNum>>8)&0xf].Text()
|
||||
}
|
||||
column := c.column
|
||||
if idxNum&0xf000 != 0 {
|
||||
column = arg[(idxNum>>12)&0xf].Text()
|
||||
}
|
||||
parent := c.parent
|
||||
if idxNum&0xf0000 != 0 {
|
||||
parent = arg[(idxNum>>16)&0xf].Text()
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT %[1]s.%[2]s FROM %[1]s WHERE %[1]s.%[3]s=?`,
|
||||
sqlite3.QuoteIdentifier(table),
|
||||
sqlite3.QuoteIdentifier(column),
|
||||
sqlite3.QuoteIdentifier(parent),
|
||||
)
|
||||
stmt, _, err := c.db.Prepare(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
c.nodes = []node{{root, 0}}
|
||||
set := util.Set[int64]{}
|
||||
set.Add(root)
|
||||
for i := 0; i < len(c.nodes); i++ {
|
||||
curr := c.nodes[i]
|
||||
if curr.depth >= maxDepth {
|
||||
continue
|
||||
}
|
||||
stmt.BindInt64(1, curr.id)
|
||||
for stmt.Step() {
|
||||
if stmt.ColumnType(0) == sqlite3.INTEGER {
|
||||
next := stmt.ColumnInt64(0)
|
||||
if !set.Contains(next) {
|
||||
set.Add(next)
|
||||
c.nodes = append(c.nodes, node{next, curr.depth + 1})
|
||||
}
|
||||
}
|
||||
}
|
||||
stmt.Reset()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
|
||||
switch n {
|
||||
case _COL_ID:
|
||||
ctx.ResultInt64(c.nodes[0].id)
|
||||
case _COL_DEPTH:
|
||||
ctx.ResultInt(c.nodes[0].depth)
|
||||
case _COL_TABLENAME:
|
||||
ctx.ResultText(c.table)
|
||||
case _COL_IDCOLUMN:
|
||||
ctx.ResultText(c.column)
|
||||
case _COL_PARENTCOLUMN:
|
||||
ctx.ResultText(c.parent)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cursor) Next() error {
|
||||
c.nodes = c.nodes[1:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cursor) EOF() bool {
|
||||
return len(c.nodes) == 0
|
||||
}
|
||||
|
||||
func (c *cursor) RowID() (int64, error) {
|
||||
return c.nodes[0].id, nil
|
||||
}
|
||||
152
ext/closure/closure_test.go
Normal file
152
ext/closure/closure_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package closure_test
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/closure"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(closure.Register)
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func Example() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE TABLE employees (
|
||||
id INTEGER PRIMARY KEY,
|
||||
parent_id INTEGER,
|
||||
name TEXT
|
||||
);
|
||||
CREATE INDEX employees_parent_idx ON employees(parent_id);
|
||||
INSERT INTO employees (id, parent_id, name) VALUES
|
||||
(11, NULL, 'Diane'),
|
||||
(12, 11, 'Bob'),
|
||||
(21, 11, 'Emma'),
|
||||
(22, 21, 'Grace'),
|
||||
(23, 21, 'Henry'),
|
||||
(24, 21, 'Irene'),
|
||||
(25, 21, 'Frank'),
|
||||
(31, 11, 'Cindy'),
|
||||
(32, 31, 'Dave'),
|
||||
(33, 31, 'Alice');
|
||||
CREATE VIRTUAL TABLE hierarchy USING transitive_closure(
|
||||
tablename = "employees",
|
||||
idcolumn = "id",
|
||||
parentcolumn = "parent_id"
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT employees.id, name FROM employees, hierarchy
|
||||
WHERE employees.id = hierarchy.id AND hierarchy.root = 31
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// 31 Cindy
|
||||
// 32 Dave
|
||||
// 33 Alice
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE TABLE employees (
|
||||
id INTEGER PRIMARY KEY,
|
||||
parent_id INTEGER,
|
||||
name TEXT
|
||||
);
|
||||
CREATE INDEX employees_parent_idx ON employees(parent_id);
|
||||
INSERT INTO employees (id, parent_id, name) VALUES
|
||||
(11, NULL, 'Diane'),
|
||||
(12, 11, 'Bob'),
|
||||
(21, 11, 'Emma'),
|
||||
(22, 21, 'Grace'),
|
||||
(23, 21, 'Henry'),
|
||||
(24, 21, 'Irene'),
|
||||
(25, 21, 'Frank'),
|
||||
(31, 11, 'Cindy'),
|
||||
(32, 31, 'Dave'),
|
||||
(33, 31, 'Alice');
|
||||
CREATE VIRTUAL TABLE temp.closure USING transitive_closure;
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT employees.id, name FROM employees, closure
|
||||
WHERE employees.id = closure.id
|
||||
AND closure.root = 31
|
||||
AND closure.depth < 1
|
||||
AND closure.tablename='employees'
|
||||
AND closure.idcolumn='id'
|
||||
AND closure.parentcolumn='parent_id'
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if !stmt.Step() {
|
||||
t.Error("want row")
|
||||
}
|
||||
if stmt.Step() {
|
||||
t.Error("don't want row")
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -40,12 +40,12 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
comma rune = ','
|
||||
comment rune
|
||||
|
||||
done = map[string]struct{}{}
|
||||
done = util.Set[string]{}
|
||||
)
|
||||
|
||||
for _, arg := range arg {
|
||||
key, val := vtabutil.NamedArg(arg)
|
||||
if _, ok := done[key]; ok {
|
||||
if done.Contains(key) {
|
||||
return nil, fmt.Errorf("csv: more than one %q parameter", key)
|
||||
}
|
||||
switch key {
|
||||
@@ -69,7 +69,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
done[key] = struct{}{}
|
||||
done.Add(key)
|
||||
}
|
||||
|
||||
if (filename == "") == (data == "") {
|
||||
|
||||
Reference in New Issue
Block a user