Transitive closure virtual table.

This commit is contained in:
Nuno Cruces
2024-09-21 11:40:16 +01:00
parent d7376209ee
commit 2526fc8444
4 changed files with 430 additions and 3 deletions

263
ext/closure/closure.go Normal file
View File

@@ -0,0 +1,263 @@
// Package closure provides a transitive closure virtual table.
//
// The "transitive_closure" virtual table finds the transitive closure of
// a parent/child relationship in a real table.
//
// https://sqlite.org/src/doc/tip/ext/misc/closure.c
package closure
import (
"fmt"
"math"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
const (
_COL_ID = 0
_COL_DEPTH = 1
_COL_ROOT = 2
_COL_TABLENAME = 3
_COL_IDCOLUMN = 4
_COL_PARENTCOLUMN = 5
)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "transitive_closure", nil,
func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*closure, error) {
var (
table string
column string
parent string
done = util.Set[string]{}
)
for _, arg := range arg {
key, val := vtabutil.NamedArg(arg)
if done.Contains(key) {
return nil, fmt.Errorf("transitive_closure: more than one %q parameter", key)
}
switch key {
case "tablename":
table = vtabutil.Unquote(val)
case "idcolumn":
column = vtabutil.Unquote(val)
case "parentcolumn":
parent = vtabutil.Unquote(val)
default:
return nil, fmt.Errorf("transitive_closure: unknown %q parameter", key)
}
done.Add(key)
}
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
if err != nil {
return nil, err
}
return &closure{
db: db,
table: table,
column: column,
parent: parent,
}, nil
})
}
type closure struct {
db *sqlite3.Conn
table string
column string
parent string
}
func (c *closure) Destroy() error { return nil }
func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
posi := 1
plan := 0
cost := 10000000.0
for i, cst := range idx.Constraint {
if !cst.Usable {
continue
}
if plan&1 == 0 && cst.Column == _COL_ROOT {
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= 1
cost /= 100
idx.ConstraintUsage[i].ArgvIndex = 1
idx.ConstraintUsage[i].Omit = true
}
continue
}
if plan&0xf0 == 0 && cst.Column == _COL_DEPTH {
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 4
cost /= 5
posi += 1
idx.ConstraintUsage[i].ArgvIndex = posi
if cst.Op == sqlite3.INDEX_CONSTRAINT_LT {
plan |= 2
}
}
continue
}
if plan&0xf00 == 0 && cst.Column == _COL_TABLENAME {
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 8
cost /= 5
posi += 1
idx.ConstraintUsage[i].ArgvIndex = posi
idx.ConstraintUsage[i].Omit = true
}
continue
}
if plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN {
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 12
posi += 1
idx.ConstraintUsage[i].ArgvIndex = posi
idx.ConstraintUsage[i].Omit = true
}
continue
}
if plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN {
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 16
posi += 1
idx.ConstraintUsage[i].ArgvIndex = posi
idx.ConstraintUsage[i].Omit = true
}
continue
}
}
if c.table == "" && plan&0xf00 == 0 ||
c.column == "" && plan&0xf000 == 0 ||
c.parent == "" && plan&0xf0000 == 0 {
plan = 0
}
if plan&1 == 0 {
plan = 0
cost *= 1e30
for i := range idx.Constraint {
idx.ConstraintUsage[i].ArgvIndex = 0
}
}
idx.EstimatedCost = cost
idx.IdxNum = plan
return nil
}
func (c *closure) Open() (sqlite3.VTabCursor, error) {
return &cursor{closure: c}, nil
}
type cursor struct {
*closure
nodes []node
}
type node struct {
id int64
depth int
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
if idxNum&1 == 0 {
return nil
}
root := arg[0].Int64()
maxDepth := math.MaxInt
if idxNum&0xf0 != 0 {
maxDepth = arg[(idxNum>>4)&0xf].Int()
if idxNum&2 != 0 {
maxDepth -= 1
}
}
table := c.table
if idxNum&0xf00 != 0 {
table = arg[(idxNum>>8)&0xf].Text()
}
column := c.column
if idxNum&0xf000 != 0 {
column = arg[(idxNum>>12)&0xf].Text()
}
parent := c.parent
if idxNum&0xf0000 != 0 {
parent = arg[(idxNum>>16)&0xf].Text()
}
sql := fmt.Sprintf(
`SELECT %[1]s.%[2]s FROM %[1]s WHERE %[1]s.%[3]s=?`,
sqlite3.QuoteIdentifier(table),
sqlite3.QuoteIdentifier(column),
sqlite3.QuoteIdentifier(parent),
)
stmt, _, err := c.db.Prepare(sql)
if err != nil {
return err
}
defer stmt.Close()
c.nodes = []node{{root, 0}}
set := util.Set[int64]{}
set.Add(root)
for i := 0; i < len(c.nodes); i++ {
curr := c.nodes[i]
if curr.depth >= maxDepth {
continue
}
stmt.BindInt64(1, curr.id)
for stmt.Step() {
if stmt.ColumnType(0) == sqlite3.INTEGER {
next := stmt.ColumnInt64(0)
if !set.Contains(next) {
set.Add(next)
c.nodes = append(c.nodes, node{next, curr.depth + 1})
}
}
}
stmt.Reset()
}
return nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
switch n {
case _COL_ID:
ctx.ResultInt64(c.nodes[0].id)
case _COL_DEPTH:
ctx.ResultInt(c.nodes[0].depth)
case _COL_TABLENAME:
ctx.ResultText(c.table)
case _COL_IDCOLUMN:
ctx.ResultText(c.column)
case _COL_PARENTCOLUMN:
ctx.ResultText(c.parent)
}
return nil
}
func (c *cursor) Next() error {
c.nodes = c.nodes[1:]
return nil
}
func (c *cursor) EOF() bool {
return len(c.nodes) == 0
}
func (c *cursor) RowID() (int64, error) {
return c.nodes[0].id, nil
}

152
ext/closure/closure_test.go Normal file
View File

@@ -0,0 +1,152 @@
package closure_test
import (
_ "embed"
"fmt"
"log"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/closure"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func TestMain(m *testing.M) {
sqlite3.AutoExtension(closure.Register)
m.Run()
}
func Example() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`
CREATE TABLE employees (
id INTEGER PRIMARY KEY,
parent_id INTEGER,
name TEXT
);
CREATE INDEX employees_parent_idx ON employees(parent_id);
INSERT INTO employees (id, parent_id, name) VALUES
(11, NULL, 'Diane'),
(12, 11, 'Bob'),
(21, 11, 'Emma'),
(22, 21, 'Grace'),
(23, 21, 'Henry'),
(24, 21, 'Irene'),
(25, 21, 'Frank'),
(31, 11, 'Cindy'),
(32, 31, 'Dave'),
(33, 31, 'Alice');
CREATE VIRTUAL TABLE hierarchy USING transitive_closure(
tablename = "employees",
idcolumn = "id",
parentcolumn = "parent_id"
);
`)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`
SELECT employees.id, name FROM employees, hierarchy
WHERE employees.id = hierarchy.id AND hierarchy.root = 31
`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
err = stmt.Close()
if err != nil {
log.Fatal(err)
}
err = db.Close()
if err != nil {
log.Fatal(err)
}
// Output:
// 31 Cindy
// 32 Dave
// 33 Alice
}
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`
CREATE TABLE employees (
id INTEGER PRIMARY KEY,
parent_id INTEGER,
name TEXT
);
CREATE INDEX employees_parent_idx ON employees(parent_id);
INSERT INTO employees (id, parent_id, name) VALUES
(11, NULL, 'Diane'),
(12, 11, 'Bob'),
(21, 11, 'Emma'),
(22, 21, 'Grace'),
(23, 21, 'Henry'),
(24, 21, 'Irene'),
(25, 21, 'Frank'),
(31, 11, 'Cindy'),
(32, 31, 'Dave'),
(33, 31, 'Alice');
CREATE VIRTUAL TABLE temp.closure USING transitive_closure;
`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`
SELECT employees.id, name FROM employees, closure
WHERE employees.id = closure.id
AND closure.root = 31
AND closure.depth < 1
AND closure.tablename='employees'
AND closure.idcolumn='id'
AND closure.parentcolumn='parent_id'
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Error("want row")
}
if stmt.Step() {
t.Error("don't want row")
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -40,12 +40,12 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
comma rune = ','
comment rune
done = map[string]struct{}{}
done = util.Set[string]{}
)
for _, arg := range arg {
key, val := vtabutil.NamedArg(arg)
if _, ok := done[key]; ok {
if done.Contains(key) {
return nil, fmt.Errorf("csv: more than one %q parameter", key)
}
switch key {
@@ -69,7 +69,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
if err != nil {
return nil, err
}
done[key] = struct{}{}
done.Add(key)
}
if (filename == "") == (data == "") {

12
internal/util/set.go Normal file
View File

@@ -0,0 +1,12 @@
package util
type Set[E comparable] map[E]struct{}
func (s Set[E]) Add(v E) {
s[v] = struct{}{}
}
func (s Set[E]) Contains(v E) bool {
_, ok := s[v]
return ok
}