From 26adda4529afa2165feadec5c471a663e05dba02 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 8 Mar 2025 14:07:43 +0000 Subject: [PATCH] Seq aggregate functions (#229) --- ext/fileio/fileio.go | 2 +- ext/fileio/fsdir.go | 20 ++++---- ext/stats/boolean.go | 2 +- ext/stats/percentile.go | 2 +- ext/stats/stats.go | 6 +-- func.go | 110 +++++++++++++++++++++++++++++++++------- func_seq_test.go | 57 +++++++++++++++++++++ 7 files changed, 165 insertions(+), 34 deletions(-) create mode 100644 func_seq_test.go diff --git a/ext/fileio/fileio.go b/ext/fileio/fileio.go index 234abee..5b2a67c 100644 --- a/ext/fileio/fileio.go +++ b/ext/fileio/fileio.go @@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) { ctx.ResultText(fs.FileMode(arg[0].Int()).String()) } -func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) { +func readfile(fsys fs.FS) sqlite3.ScalarFunction { return func(ctx sqlite3.Context, arg ...sqlite3.Value) { var err error var data []byte diff --git a/ext/fileio/fsdir.go b/ext/fileio/fsdir.go index dc22774..734cb23 100644 --- a/ext/fileio/fsdir.go +++ b/ext/fileio/fsdir.go @@ -63,12 +63,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) { type cursor struct { fsdir - base string - resume func() (entry, bool) - cancel func() - curr entry - eof bool - rowID int64 + base string + next func() (entry, bool) + stop func() + curr entry + eof bool + rowID int64 } type entry struct { @@ -78,8 +78,8 @@ type entry struct { } func (c *cursor) Close() error { - if c.cancel != nil { - c.cancel() + if c.stop != nil { + c.stop() } return nil } @@ -102,7 +102,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { c.base = base } - c.resume, c.cancel = iter.Pull(func(yield func(entry) bool) { + c.next, c.stop = iter.Pull(func(yield func(entry) bool) { walkDir := func(path string, d fs.DirEntry, err error) error { if yield(entry{d, err, path}) { return nil @@ -121,7 +121,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { } func (c *cursor) Next() error { - curr, ok := c.resume() + curr, ok := c.next() c.curr = curr c.eof = !ok c.rowID++ diff --git a/ext/stats/boolean.go b/ext/stats/boolean.go index f30fa07..7aa5aaf 100644 --- a/ext/stats/boolean.go +++ b/ext/stats/boolean.go @@ -7,7 +7,7 @@ const ( some ) -func newBoolean(kind int) func() sqlite3.AggregateFunction { +func newBoolean(kind int) sqlite3.AggregateConstructor { return func() sqlite3.AggregateFunction { return &boolean{kind: kind} } } diff --git a/ext/stats/percentile.go b/ext/stats/percentile.go index 2e63f74..a84c54d 100644 --- a/ext/stats/percentile.go +++ b/ext/stats/percentile.go @@ -21,7 +21,7 @@ const ( percentile_disc ) -func newPercentile(kind int) func() sqlite3.AggregateFunction { +func newPercentile(kind int) sqlite3.AggregateConstructor { return func() sqlite3.AggregateFunction { return &percentile{kind: kind} } } diff --git a/ext/stats/stats.go b/ext/stats/stats.go index b0c7ff3..a345cb8 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -130,7 +130,7 @@ func special(kind int, n int64) (null, zero bool) { } } -func newVariance(kind int) func() sqlite3.AggregateFunction { +func newVariance(kind int) sqlite3.AggregateConstructor { return func() sqlite3.AggregateFunction { return &variance{kind: kind} } } @@ -178,7 +178,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { } } -func newCovariance(kind int) func() sqlite3.AggregateFunction { +func newCovariance(kind int) sqlite3.AggregateConstructor { return func() sqlite3.AggregateFunction { return &covariance{kind: kind} } } @@ -254,7 +254,7 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { } } -func newMoments(kind int) func() sqlite3.AggregateFunction { +func newMoments(kind int) sqlite3.AggregateConstructor { return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} } } diff --git a/func.go b/func.go index f907fa9..e0dcb37 100644 --- a/func.go +++ b/func.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "io" + "iter" "sync" "github.com/tetratelabs/wazero/api" @@ -45,7 +46,7 @@ func (c Conn) AnyCollationNeeded() error { // CreateCollation defines a new collating sequence. // // https://sqlite.org/c3ref/create_collation.html -func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { +func (c *Conn) CreateCollation(name string, fn CollatingFunction) error { var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) @@ -57,6 +58,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { return c.error(rc) } +// Collating function is the type of a collation callback. +// Implementations must not retain a or b. +type CollatingFunction func(a, b []byte) int + // CreateFunction defines a new scalar SQL function. // // https://sqlite.org/c3ref/create_function.html @@ -77,34 +82,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala // Implementations must not retain arg. type ScalarFunction func(ctx Context, arg ...Value) -// CreateWindowFunction defines a new aggregate or aggregate window SQL function. -// If fn returns a [WindowFunction], then an aggregate window function is created. -// If fn returns an [io.Closer], it will be called to free resources. +// CreateAggregateFunction defines a new aggregate SQL function. // // https://sqlite.org/c3ref/create_function.html -func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { +func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error { var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) - call := "sqlite3_create_aggregate_function_go" if fn != nil { - agg := fn() - if c, ok := agg.(io.Closer); ok { - if err := c.Close(); err != nil { - return err + funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction { + var a aggregateFunc + coro := func(yieldCoro func(struct{}) bool) { + seq := func(yieldSeq func([]Value) bool) { + for yieldSeq(a.arg) { + if !yieldCoro(struct{}{}) { + break + } + } + } + fn(&a.ctx, seq) } - } - if _, ok := agg.(WindowFunction); ok { - call = "sqlite3_create_window_function_go" - } - funcPtr = util.AddHandle(c.ctx, fn) + a.next, a.stop = iter.Pull(coro) + return &a + })) } - rc := res_t(c.call(call, + rc := res_t(c.call("sqlite3_create_aggregate_function_go", stk_t(c.handle), stk_t(namePtr), stk_t(nArg), stk_t(flag), stk_t(funcPtr))) return c.error(rc) } +// AggregateSeqFunction is the type of an aggregate SQL function. +// Implementations must not retain the slices produced by seq. +type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value]) + +// CreateWindowFunction defines a new aggregate or aggregate window SQL function. +// If fn returns a [WindowFunction], an aggregate window function is created. +// If fn returns an [io.Closer], it will be called to free resources. +// +// https://sqlite.org/c3ref/create_function.html +func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error { + var funcPtr ptr_t + defer c.arena.mark()() + namePtr := c.arena.string(name) + if fn != nil { + funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction { + agg := fn() + if win, ok := agg.(WindowFunction); ok { + return win + } + return windowFunc{agg, name} + })) + } + rc := res_t(c.call("sqlite3_create_window_function_go", + stk_t(c.handle), stk_t(namePtr), stk_t(nArg), + stk_t(flag), stk_t(funcPtr))) + return c.error(rc) +} + +// AggregateConstructor is a an [AggregateFunction] constructor. +type AggregateConstructor func() AggregateFunction + // AggregateFunction is the interface an aggregate function should implement. // // https://sqlite.org/appfunc.html @@ -153,7 +191,7 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe } func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 { - fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) + fn := util.GetHandle(ctx, pApp).(CollatingFunction) return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2)))) } @@ -211,7 +249,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { } // We need to create the aggregate. - fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() + fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)() if pAgg != 0 { handle := util.AddHandle(db.ctx, fn) util.Write32(db.mod, pAgg, handle) @@ -232,6 +270,7 @@ func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { var funcArgsPool sync.Pool func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { + clear(p[:]) funcArgsPool.Put(p) } @@ -242,3 +281,38 @@ func getFuncArgs() *[_MAX_FUNCTION_ARG]Value { return p.(*[_MAX_FUNCTION_ARG]Value) } } + +type aggregateFunc struct { + ctx Context + arg []Value + next func() (struct{}, bool) + stop func() +} + +func (a *aggregateFunc) Step(ctx Context, arg ...Value) { + a.ctx = ctx + a.arg = arg + if _, more := a.next(); !more { + a.stop() + } +} + +func (a *aggregateFunc) Value(ctx Context) { + a.ctx = ctx + a.stop() +} + +func (a *aggregateFunc) Close() error { + a.stop() + return nil +} + +type windowFunc struct { + AggregateFunction + name string +} + +func (w windowFunc) Inverse(ctx Context, arg ...Value) { + // Implementing inverse allows certain queries that don't really need it to succeed. + ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function")) +} diff --git a/func_seq_test.go b/func_seq_test.go new file mode 100644 index 0000000..f05c2ca --- /dev/null +++ b/func_seq_test.go @@ -0,0 +1,57 @@ +package sqlite3_test + +import ( + "fmt" + "iter" + "log" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func ExampleConn_CreateAggregateFunction() { + db, err := sqlite3.Open(":memory:") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE test (col)`) + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`) + if err != nil { + log.Fatal(err) + } + + err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, + func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) { + count := 0 + total := 0.0 + for arg := range seq { + total += arg[0].Float() + count++ + } + ctx.ResultFloat(total / float64(count)) + }) + if err != nil { + log.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + for stmt.Step() { + fmt.Println(stmt.ColumnFloat(0)) + } + if err := stmt.Err(); err != nil { + log.Fatal(err) + } + // Output: + // 2 +}