mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
266 lines
5.6 KiB
Go
266 lines
5.6 KiB
Go
// Package closure provides a transitive closure virtual table.
|
|
//
|
|
// The transitive_closure virtual table finds the transitive closure of
|
|
// a parent/child relationship in a real table.
|
|
//
|
|
// https://sqlite.org/src/doc/tip/ext/misc/closure.c
|
|
package closure
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
|
|
"github.com/ncruces/go-sqlite3"
|
|
"github.com/ncruces/go-sqlite3/internal/util"
|
|
"github.com/ncruces/go-sqlite3/util/sql3util"
|
|
)
|
|
|
|
// Column indexes of the virtual table, in the order they are declared by the
// CREATE TABLE statement passed to DeclareVTab in Register.
const (
	_COL_ID = iota
	_COL_DEPTH
	_COL_ROOT
	_COL_TABLENAME
	_COL_IDCOLUMN
	_COL_PARENTCOLUMN
)
|
|
|
|
// Register registers the transitive_closure virtual table:
|
|
//
|
|
// CREATE VIRTUAL TABLE temp.closure USING transitive_closure;
|
|
func Register(db *sqlite3.Conn) error {
|
|
return sqlite3.CreateModule(db, "transitive_closure", nil,
|
|
func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*closure, error) {
|
|
var (
|
|
table string
|
|
column string
|
|
parent string
|
|
|
|
done = util.Set[string]{}
|
|
)
|
|
|
|
for _, arg := range arg {
|
|
key, val := sql3util.NamedArg(arg)
|
|
if done.Contains(key) {
|
|
return nil, fmt.Errorf("transitive_closure: more than one %q parameter", key)
|
|
}
|
|
switch key {
|
|
case "tablename":
|
|
table = sql3util.Unquote(val)
|
|
case "idcolumn":
|
|
column = sql3util.Unquote(val)
|
|
case "parentcolumn":
|
|
parent = sql3util.Unquote(val)
|
|
default:
|
|
return nil, fmt.Errorf("transitive_closure: unknown %q parameter", key)
|
|
}
|
|
done.Add(key)
|
|
}
|
|
|
|
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &closure{
|
|
db: db,
|
|
table: table,
|
|
column: column,
|
|
parent: parent,
|
|
}, nil
|
|
})
|
|
}
|
|
|
|
type closure struct {
|
|
db *sqlite3.Conn
|
|
table string
|
|
column string
|
|
parent string
|
|
}
|
|
|
|
// Destroy has nothing to release and always succeeds.
func (c *closure) Destroy() error {
	return nil
}
|
|
|
|
func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
|
plan := 0
|
|
posi := 1
|
|
cost := 1e7
|
|
|
|
for i, cst := range idx.Constraint {
|
|
switch {
|
|
case !cst.Usable:
|
|
continue
|
|
|
|
case plan&1 == 0 && cst.Column == _COL_ROOT:
|
|
switch cst.Op {
|
|
case sqlite3.INDEX_CONSTRAINT_EQ:
|
|
plan |= 1
|
|
cost /= 100
|
|
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
|
|
ArgvIndex: 1,
|
|
Omit: true,
|
|
}
|
|
}
|
|
|
|
case plan&0xf0 == 0 && cst.Column == _COL_DEPTH:
|
|
switch cst.Op {
|
|
case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
|
|
plan |= posi << 4
|
|
cost /= 5
|
|
posi += 1
|
|
idx.ConstraintUsage[i].ArgvIndex = posi
|
|
if cst.Op == sqlite3.INDEX_CONSTRAINT_LT {
|
|
plan |= 2
|
|
}
|
|
}
|
|
|
|
case plan&0xf00 == 0 && cst.Column == _COL_TABLENAME:
|
|
switch cst.Op {
|
|
case sqlite3.INDEX_CONSTRAINT_EQ:
|
|
plan |= posi << 8
|
|
cost /= 5
|
|
posi += 1
|
|
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
|
|
ArgvIndex: posi,
|
|
Omit: true,
|
|
}
|
|
}
|
|
|
|
case plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN:
|
|
switch cst.Op {
|
|
case sqlite3.INDEX_CONSTRAINT_EQ:
|
|
plan |= posi << 12
|
|
posi += 1
|
|
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
|
|
ArgvIndex: posi,
|
|
Omit: true,
|
|
}
|
|
}
|
|
|
|
case plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN:
|
|
switch cst.Op {
|
|
case sqlite3.INDEX_CONSTRAINT_EQ:
|
|
plan |= posi << 16
|
|
posi += 1
|
|
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
|
|
ArgvIndex: posi,
|
|
Omit: true,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if plan&1 == 0 ||
|
|
c.table == "" && plan&0xf00 == 0 ||
|
|
c.column == "" && plan&0xf000 == 0 ||
|
|
c.parent == "" && plan&0xf0000 == 0 {
|
|
return sqlite3.CONSTRAINT
|
|
}
|
|
|
|
idx.IdxFlags = sqlite3.INDEX_SCAN_HEX
|
|
idx.EstimatedCost = cost
|
|
idx.IdxNum = plan
|
|
return nil
|
|
}
|
|
|
|
func (c *closure) Open() (sqlite3.VTabCursor, error) {
|
|
return &cursor{closure: c}, nil
|
|
}
|
|
|
|
type cursor struct {
|
|
*closure
|
|
nodes []node
|
|
}
|
|
|
|
type node struct {
|
|
id int64
|
|
depth int
|
|
}
|
|
|
|
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
|
root := arg[0].Int64()
|
|
maxDepth := math.MaxInt
|
|
if idxNum&0xf0 != 0 {
|
|
maxDepth = arg[(idxNum>>4)&0xf].Int()
|
|
if idxNum&2 != 0 {
|
|
maxDepth -= 1
|
|
}
|
|
}
|
|
table := c.table
|
|
if idxNum&0xf00 != 0 {
|
|
table = arg[(idxNum>>8)&0xf].Text()
|
|
}
|
|
column := c.column
|
|
if idxNum&0xf000 != 0 {
|
|
column = arg[(idxNum>>12)&0xf].Text()
|
|
}
|
|
parent := c.parent
|
|
if idxNum&0xf0000 != 0 {
|
|
parent = arg[(idxNum>>16)&0xf].Text()
|
|
}
|
|
|
|
sql := fmt.Sprintf(
|
|
`SELECT %[1]s.%[2]s FROM %[1]s WHERE %[1]s.%[3]s=?`,
|
|
sqlite3.QuoteIdentifier(table),
|
|
sqlite3.QuoteIdentifier(column),
|
|
sqlite3.QuoteIdentifier(parent),
|
|
)
|
|
stmt, _, err := c.db.Prepare(sql)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
c.nodes = []node{{root, 0}}
|
|
set := util.Set[int64]{}
|
|
set.Add(root)
|
|
for i := range c.nodes {
|
|
curr := c.nodes[i]
|
|
if curr.depth >= maxDepth {
|
|
continue
|
|
}
|
|
if err := stmt.BindInt64(1, curr.id); err != nil {
|
|
return err
|
|
}
|
|
for stmt.Step() {
|
|
if stmt.ColumnType(0) == sqlite3.INTEGER {
|
|
next := stmt.ColumnInt64(0)
|
|
if !set.Contains(next) {
|
|
set.Add(next)
|
|
c.nodes = append(c.nodes, node{next, curr.depth + 1})
|
|
}
|
|
}
|
|
}
|
|
if err := stmt.Reset(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
|
|
switch n {
|
|
case _COL_ID:
|
|
ctx.ResultInt64(c.nodes[0].id)
|
|
case _COL_DEPTH:
|
|
ctx.ResultInt(c.nodes[0].depth)
|
|
case _COL_TABLENAME:
|
|
ctx.ResultText(c.table)
|
|
case _COL_IDCOLUMN:
|
|
ctx.ResultText(c.column)
|
|
case _COL_PARENTCOLUMN:
|
|
ctx.ResultText(c.parent)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *cursor) Next() error {
|
|
c.nodes = c.nodes[1:]
|
|
return nil
|
|
}
|
|
|
|
func (c *cursor) EOF() bool {
|
|
return len(c.nodes) == 0
|
|
}
|
|
|
|
func (c *cursor) RowID() (int64, error) {
|
|
return c.nodes[0].id, nil
|
|
}
|