From 2526fc8444c1493b009818ab92eaddf30f27d402 Mon Sep 17 00:00:00 2001
From: Nuno Cruces
Date: Sat, 21 Sep 2024 11:40:16 +0100
Subject: [PATCH] Transitive closure virtual table.

---
 ext/closure/closure.go      | 315 ++++++++++++++++++++++++++++++++++++
 ext/closure/closure_test.go | 156 ++++++++++++++++++
 ext/csv/csv.go              |   6 +-
 internal/util/set.go        |  15 ++
 4 files changed, 489 insertions(+), 3 deletions(-)
 create mode 100644 ext/closure/closure.go
 create mode 100644 ext/closure/closure_test.go
 create mode 100644 internal/util/set.go

diff --git a/ext/closure/closure.go b/ext/closure/closure.go
new file mode 100644
index 0000000..d586bb9
--- /dev/null
+++ b/ext/closure/closure.go
@@ -0,0 +1,315 @@
+// Package closure provides a transitive closure virtual table.
+//
+// The "transitive_closure" virtual table finds the transitive closure of
+// a parent/child relationship in a real table.
+//
+// https://sqlite.org/src/doc/tip/ext/misc/closure.c
+package closure
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/ncruces/go-sqlite3"
+	"github.com/ncruces/go-sqlite3/internal/util"
+	"github.com/ncruces/go-sqlite3/util/vtabutil"
+)
+
+// Column positions, in the order they are declared to SQLite in Register.
+const (
+	_COL_ID           = 0
+	_COL_DEPTH        = 1
+	_COL_ROOT         = 2
+	_COL_TABLENAME    = 3
+	_COL_IDCOLUMN     = 4
+	_COL_PARENTCOLUMN = 5
+)
+
+// Register registers the "transitive_closure" virtual table module with db.
+//
+// The module accepts the named arguments tablename, idcolumn and
+// parentcolumn. Each may be given at most once; any other argument
+// is rejected with an error.
+func Register(db *sqlite3.Conn) error {
+	return sqlite3.CreateModule(db, "transitive_closure", nil,
+		func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*closure, error) {
+			var (
+				table  string
+				column string
+				parent string
+
+				done = util.Set[string]{}
+			)
+
+			for _, arg := range arg {
+				key, val := vtabutil.NamedArg(arg)
+				if done.Contains(key) {
+					return nil, fmt.Errorf("transitive_closure: more than one %q parameter", key)
+				}
+				switch key {
+				case "tablename":
+					table = vtabutil.Unquote(val)
+				case "idcolumn":
+					column = vtabutil.Unquote(val)
+				case "parentcolumn":
+					parent = vtabutil.Unquote(val)
+				default:
+					return nil, fmt.Errorf("transitive_closure: unknown %q parameter", key)
+				}
+				done.Add(key)
+			}
+
+			err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
+			if err != nil {
+				return nil, err
+			}
+			return &closure{
+				db:     db,
+				table:  table,
+				column: column,
+				parent: parent,
+			}, nil
+		})
+}
+
+// closure is the virtual table. Its table, column and parent fields hold
+// the values given when the table was declared, and are empty otherwise.
+type closure struct {
+	db     *sqlite3.Conn
+	table  string
+	column string
+	parent string
+}
+
+// Destroy has nothing to release.
+func (c *closure) Destroy() error { return nil }
+
+// BestIndex encodes the usable constraints into idx.IdxNum for Filter:
+//
+//	bit 0       root=? is present, and is the first argument
+//	bit 1       the depth constraint is a strict depth<?
+//	bits 4-7    argument index of the depth constraint
+//	bits 8-11   argument index of tablename=?
+//	bits 12-15  argument index of idcolumn=?
+//	bits 16-19  argument index of parentcolumn=?
+//
+// The plan is 0, and Filter yields no rows, unless there is a usable root=?
+// constraint, and each of tablename, idcolumn and parentcolumn is known
+// either from the table declaration or from a constraint.
+func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
+	posi := 1 // index into Filter's arg of the next constraint; 0 is the root
+	plan := 0
+	cost := 10000000.0
+
+	for i, cst := range idx.Constraint {
+		if !cst.Usable {
+			continue
+		}
+		if plan&1 == 0 && cst.Column == _COL_ROOT {
+			switch cst.Op {
+			case sqlite3.INDEX_CONSTRAINT_EQ:
+				plan |= 1
+				cost /= 100
+				idx.ConstraintUsage[i].ArgvIndex = 1
+				idx.ConstraintUsage[i].Omit = true
+			}
+			continue
+		}
+		if plan&0xf0 == 0 && cst.Column == _COL_DEPTH {
+			switch cst.Op {
+			case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
+				plan |= posi << 4
+				cost /= 5
+				posi += 1
+				idx.ConstraintUsage[i].ArgvIndex = posi
+				if cst.Op == sqlite3.INDEX_CONSTRAINT_LT {
+					plan |= 2
+				}
+			}
+			continue
+		}
+		if plan&0xf00 == 0 && cst.Column == _COL_TABLENAME {
+			switch cst.Op {
+			case sqlite3.INDEX_CONSTRAINT_EQ:
+				plan |= posi << 8
+				cost /= 5
+				posi += 1
+				idx.ConstraintUsage[i].ArgvIndex = posi
+				idx.ConstraintUsage[i].Omit = true
+			}
+			continue
+		}
+		if plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN {
+			switch cst.Op {
+			case sqlite3.INDEX_CONSTRAINT_EQ:
+				plan |= posi << 12
+				posi += 1
+				idx.ConstraintUsage[i].ArgvIndex = posi
+				idx.ConstraintUsage[i].Omit = true
+			}
+			continue
+		}
+		if plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN {
+			switch cst.Op {
+			case sqlite3.INDEX_CONSTRAINT_EQ:
+				plan |= posi << 16
+				posi += 1
+				idx.ConstraintUsage[i].ArgvIndex = posi
+				idx.ConstraintUsage[i].Omit = true
+			}
+			continue
+		}
+	}
+
+	if c.table == "" && plan&0xf00 == 0 ||
+		c.column == "" && plan&0xf000 == 0 ||
+		c.parent == "" && plan&0xf0000 == 0 {
+		plan = 0
+	}
+	if plan&1 == 0 {
+		plan = 0
+		cost *= 1e30
+		// Plan 0 takes no arguments: undo any ArgvIndex assigned above.
+		for i := range idx.Constraint {
+			idx.ConstraintUsage[i].ArgvIndex = 0
+		}
+	}
+
+	idx.EstimatedCost = cost
+	idx.IdxNum = plan
+	return nil
+}
+
+// Open returns a cursor with no rows; Filter populates it.
+func (c *closure) Open() (sqlite3.VTabCursor, error) {
+	return &cursor{closure: c}, nil
+}
+
+// cursor iterates over the rows computed by Filter. Its table, column and
+// parent fields shadow those of the embedded closure, and hold the values
+// used by the most recent call to Filter.
+type cursor struct {
+	*closure
+	nodes  []node
+	table  string
+	column string
+	parent string
+}
+
+// node is a row of the result: an id and its distance from the root.
+type node struct {
+	id    int64
+	depth int
+}
+
+// Filter computes the rows for a plan chosen by BestIndex, by a
+// breadth-first search of the parent/child relationship starting at the
+// root. Each id is visited once, so cycles in the data terminate.
+func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
+	// Drop any rows left over from a previous call.
+	c.nodes = nil
+	if idxNum&1 == 0 {
+		return nil
+	}
+
+	root := arg[0].Int64()
+	maxDepth := math.MaxInt
+	if idxNum&0xf0 != 0 {
+		maxDepth = arg[(idxNum>>4)&0xf].Int()
+		if idxNum&2 != 0 {
+			maxDepth -= 1 // depth<? rather than depth<=?
+		}
+	}
+
+	// Values from the WHERE clause override the declared ones.
+	// They are kept on the cursor so Column reports what was used.
+	c.table = c.closure.table
+	if idxNum&0xf00 != 0 {
+		c.table = arg[(idxNum>>8)&0xf].Text()
+	}
+	c.column = c.closure.column
+	if idxNum&0xf000 != 0 {
+		c.column = arg[(idxNum>>12)&0xf].Text()
+	}
+	c.parent = c.closure.parent
+	if idxNum&0xf0000 != 0 {
+		c.parent = arg[(idxNum>>16)&0xf].Text()
+	}
+
+	sql := fmt.Sprintf(
+		`SELECT %[1]s.%[2]s FROM %[1]s WHERE %[1]s.%[3]s=?`,
+		sqlite3.QuoteIdentifier(c.table),
+		sqlite3.QuoteIdentifier(c.column),
+		sqlite3.QuoteIdentifier(c.parent),
+	)
+	stmt, _, err := c.db.Prepare(sql)
+	if err != nil {
+		return err
+	}
+	defer stmt.Close()
+
+	// nodes is both the queue of the search and its result.
+	nodes := []node{{root, 0}}
+	seen := util.Set[int64]{}
+	seen.Add(root)
+	for i := 0; i < len(nodes); i++ {
+		curr := nodes[i]
+		if curr.depth >= maxDepth {
+			continue
+		}
+		if err := stmt.BindInt64(1, curr.id); err != nil {
+			return err
+		}
+		for stmt.Step() {
+			// Only integer ids are followed.
+			if stmt.ColumnType(0) == sqlite3.INTEGER {
+				next := stmt.ColumnInt64(0)
+				if !seen.Contains(next) {
+					seen.Add(next)
+					nodes = append(nodes, node{next, curr.depth + 1})
+				}
+			}
+		}
+		// Reset also reports the error, if any, that stopped Step.
+		if err := stmt.Reset(); err != nil {
+			return err
+		}
+	}
+	c.nodes = nodes
+	return nil
+}
+
+// Column reports column n of the current row. Nothing is set for the root
+// column. The tablename, idcolumn and parentcolumn columns report the
+// values used by Filter, including any given in the WHERE clause.
+func (c *cursor) Column(ctx sqlite3.Context, n int) error {
+	switch n {
+	case _COL_ID:
+		ctx.ResultInt64(c.nodes[0].id)
+	case _COL_DEPTH:
+		ctx.ResultInt(c.nodes[0].depth)
+	case _COL_TABLENAME:
+		ctx.ResultText(c.table)
+	case _COL_IDCOLUMN:
+		ctx.ResultText(c.column)
+	case _COL_PARENTCOLUMN:
+		ctx.ResultText(c.parent)
+	}
+	return nil
+}
+
+// Next advances to the next row.
+func (c *cursor) Next() error {
+	c.nodes = c.nodes[1:]
+	return nil
+}
+
+// EOF reports whether all rows have been consumed.
+func (c *cursor) EOF() bool {
+	return len(c.nodes) == 0
+}
+
+// RowID reports the id of the current row.
+func (c *cursor) RowID() (int64, error) {
+	return c.nodes[0].id, nil
+}
diff --git a/ext/closure/closure_test.go b/ext/closure/closure_test.go
new file mode 100644
index 0000000..f1b3da6
--- /dev/null
+++ b/ext/closure/closure_test.go
@@ -0,0 +1,156 @@
+package closure_test
+
+import (
+	_ "embed"
+	"fmt"
+	"log"
+	"testing"
+
+	"github.com/ncruces/go-sqlite3"
+	_ "github.com/ncruces/go-sqlite3/embed"
+	"github.com/ncruces/go-sqlite3/ext/closure"
+	_ "github.com/ncruces/go-sqlite3/internal/testcfg"
+)
+
+// TestMain installs closure.Register as an auto extension before running the tests.
+func TestMain(m *testing.M) {
+	sqlite3.AutoExtension(closure.Register)
+	m.Run()
+}
+
+func Example() {
+	db, err := sqlite3.Open(":memory:")
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer db.Close()
+
+	err = db.Exec(`
+		CREATE TABLE employees (
+			id INTEGER PRIMARY KEY,
+			parent_id INTEGER,
+			name TEXT
+		);
+		CREATE INDEX employees_parent_idx ON employees(parent_id);
+		INSERT INTO employees (id, parent_id, name) VALUES
+			(11, NULL, 'Diane'),
+			(12, 11, 'Bob'),
+			(21, 11, 'Emma'),
+			(22, 21, 'Grace'),
+			(23, 21, 'Henry'),
+			(24, 21, 'Irene'),
+			(25, 21, 'Frank'),
+			(31, 11, 'Cindy'),
+			(32, 31, 'Dave'),
+			(33, 31, 'Alice');
+		CREATE VIRTUAL TABLE hierarchy USING transitive_closure(
+			tablename = "employees",
+			idcolumn = "id",
+			parentcolumn = "parent_id"
+		);
+	`)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	stmt, _, err := db.Prepare(`
+		SELECT employees.id, name FROM employees, hierarchy
+		WHERE employees.id = hierarchy.id AND hierarchy.root = 31
+	`)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer stmt.Close()
+
+	for stmt.Step() {
+		fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
+	}
+	if err := stmt.Err(); err != nil {
+		log.Fatal(err)
+	}
+
+	err = stmt.Close()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	err = db.Close()
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Output:
+	// 31 Cindy
+	// 32 Dave
+	// 33 Alice
+}
+
+// TestRegister declares the table without arguments, and supplies tablename,
+// idcolumn and parentcolumn through the WHERE clause instead.
+func TestRegister(t *testing.T) {
+	t.Parallel()
+
+	db, err := sqlite3.Open(":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	err = db.Exec(`
+		CREATE TABLE employees (
+			id INTEGER PRIMARY KEY,
+			parent_id INTEGER,
+			name TEXT
+		);
+		CREATE INDEX employees_parent_idx ON employees(parent_id);
+		INSERT INTO employees (id, parent_id, name) VALUES
+			(11, NULL, 'Diane'),
+			(12, 11, 'Bob'),
+			(21, 11, 'Emma'),
+			(22, 21, 'Grace'),
+			(23, 21, 'Henry'),
+			(24, 21, 'Irene'),
+			(25, 21, 'Frank'),
+			(31, 11, 'Cindy'),
+			(32, 31, 'Dave'),
+			(33, 31, 'Alice');
+		CREATE VIRTUAL TABLE temp.closure USING transitive_closure;
+	`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	stmt, _, err := db.Prepare(`
+		SELECT employees.id, name FROM employees, closure
+		WHERE employees.id = closure.id
+		  AND closure.root = 31
+		  AND closure.depth < 1
+		  AND closure.tablename='employees'
+		  AND closure.idcolumn='id'
+		  AND closure.parentcolumn='parent_id'
+	`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer stmt.Close()
+
+	// The root is the only node with depth < 1.
+	if !stmt.Step() {
+		t.Error("want row")
+	}
+	if stmt.Step() {
+		t.Error("don't want row")
+	}
+	if err := stmt.Err(); err != nil {
+		t.Fatal(err)
+	}
+
+	err = stmt.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/ext/csv/csv.go b/ext/csv/csv.go
index d9d0951..097380e 100644
--- a/ext/csv/csv.go
+++ b/ext/csv/csv.go
@@ -40,12 +40,12 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
 				comma   rune = ','
 				comment rune
 
-				done = map[string]struct{}{}
+				done = util.Set[string]{}
 			)
 
 			for _, arg := range arg {
 				key, val := vtabutil.NamedArg(arg)
-				if _, ok := done[key]; ok {
+				if done.Contains(key) {
 					return nil, fmt.Errorf("csv: more than one %q parameter", key)
 				}
 				switch key {
@@ -69,7 +69,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
 				if err != nil {
 					return nil, err
 				}
-				done[key] = struct{}{}
+				done.Add(key)
 			}
 
 			if (filename == "") == (data == "") {
diff --git a/internal/util/set.go b/internal/util/set.go
new file mode 100644
index 0000000..9cfd5ab
--- /dev/null
+++ b/internal/util/set.go
@@ -0,0 +1,15 @@
+package util
+
+// Set is a set of comparable elements, backed by a map with empty struct values.
+type Set[E comparable] map[E]struct{}
+
+// Add inserts v into the set.
+func (s Set[E]) Add(v E) {
+	s[v] = struct{}{}
+}
+
+// Contains reports whether v is in the set.
+func (s Set[E]) Contains(v E) bool {
+	_, ok := s[v]
+	return ok
+}