Seq aggregate functions (#229)

This commit is contained in:
Nuno Cruces
2025-03-08 14:07:43 +00:00
committed by GitHub
parent 2f6cd8de1d
commit 26adda4529
7 changed files with 165 additions and 34 deletions

View File

@@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(fs.FileMode(arg[0].Int()).String()) ctx.ResultText(fs.FileMode(arg[0].Int()).String())
} }
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) { func readfile(fsys fs.FS) sqlite3.ScalarFunction {
return func(ctx sqlite3.Context, arg ...sqlite3.Value) { return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
var err error var err error
var data []byte var data []byte

View File

@@ -64,8 +64,8 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
type cursor struct { type cursor struct {
fsdir fsdir
base string base string
resume func() (entry, bool) next func() (entry, bool)
cancel func() stop func()
curr entry curr entry
eof bool eof bool
rowID int64 rowID int64
@@ -78,8 +78,8 @@ type entry struct {
} }
func (c *cursor) Close() error { func (c *cursor) Close() error {
if c.cancel != nil { if c.stop != nil {
c.cancel() c.stop()
} }
return nil return nil
} }
@@ -102,7 +102,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.base = base c.base = base
} }
c.resume, c.cancel = iter.Pull(func(yield func(entry) bool) { c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
walkDir := func(path string, d fs.DirEntry, err error) error { walkDir := func(path string, d fs.DirEntry, err error) error {
if yield(entry{d, err, path}) { if yield(entry{d, err, path}) {
return nil return nil
@@ -121,7 +121,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
} }
func (c *cursor) Next() error { func (c *cursor) Next() error {
curr, ok := c.resume() curr, ok := c.next()
c.curr = curr c.curr = curr
c.eof = !ok c.eof = !ok
c.rowID++ c.rowID++

View File

@@ -7,7 +7,7 @@ const (
some some
) )
func newBoolean(kind int) func() sqlite3.AggregateFunction { func newBoolean(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} } return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
} }

View File

@@ -21,7 +21,7 @@ const (
percentile_disc percentile_disc
) )
func newPercentile(kind int) func() sqlite3.AggregateFunction { func newPercentile(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} } return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
} }

View File

@@ -130,7 +130,7 @@ func special(kind int, n int64) (null, zero bool) {
} }
} }
func newVariance(kind int) func() sqlite3.AggregateFunction { func newVariance(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} } return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
} }
@@ -178,7 +178,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
} }
} }
func newCovariance(kind int) func() sqlite3.AggregateFunction { func newCovariance(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} } return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
} }
@@ -254,7 +254,7 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
} }
} }
func newMoments(kind int) func() sqlite3.AggregateFunction { func newMoments(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} } return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
} }

106
func.go
View File

@@ -3,6 +3,7 @@ package sqlite3
import ( import (
"context" "context"
"io" "io"
"iter"
"sync" "sync"
"github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/api"
@@ -45,7 +46,7 @@ func (c Conn) AnyCollationNeeded() error {
// CreateCollation defines a new collating sequence. // CreateCollation defines a new collating sequence.
// //
// https://sqlite.org/c3ref/create_collation.html // https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
var funcPtr ptr_t var funcPtr ptr_t
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
@@ -57,6 +58,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
return c.error(rc) return c.error(rc)
} }
// Collating function is the type of a collation callback.
// Implementations must not retain a or b.
type CollatingFunction func(a, b []byte) int
// CreateFunction defines a new scalar SQL function. // CreateFunction defines a new scalar SQL function.
// //
// https://sqlite.org/c3ref/create_function.html // https://sqlite.org/c3ref/create_function.html
@@ -77,34 +82,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
// Implementations must not retain arg. // Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value) type ScalarFunction func(ctx Context, arg ...Value)
// CreateWindowFunction defines a new aggregate or aggregate window SQL function. // CreateAggregateFunction defines a new aggregate SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
// //
// https://sqlite.org/c3ref/create_function.html // https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
var funcPtr ptr_t var funcPtr ptr_t
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
call := "sqlite3_create_aggregate_function_go"
if fn != nil { if fn != nil {
agg := fn() funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
if c, ok := agg.(io.Closer); ok { var a aggregateFunc
if err := c.Close(); err != nil { coro := func(yieldCoro func(struct{}) bool) {
return err seq := func(yieldSeq func([]Value) bool) {
for yieldSeq(a.arg) {
if !yieldCoro(struct{}{}) {
break
} }
} }
if _, ok := agg.(WindowFunction); ok {
call = "sqlite3_create_window_function_go"
} }
funcPtr = util.AddHandle(c.ctx, fn) fn(&a.ctx, seq)
} }
rc := res_t(c.call(call, a.next, a.stop = iter.Pull(coro)
return &a
}))
}
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg), stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr))) stk_t(flag), stk_t(funcPtr)))
return c.error(rc) return c.error(rc)
} }
// AggregateSeqFunction is the type of an aggregate SQL function.
// Implementations must not retain the slices produced by seq.
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
agg := fn()
if win, ok := agg.(WindowFunction); ok {
return win
}
return windowFunc{agg, name}
}))
}
rc := res_t(c.call("sqlite3_create_window_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateConstructor is a an [AggregateFunction] constructor.
type AggregateConstructor func() AggregateFunction
// AggregateFunction is the interface an aggregate function should implement. // AggregateFunction is the interface an aggregate function should implement.
// //
// https://sqlite.org/appfunc.html // https://sqlite.org/appfunc.html
@@ -153,7 +191,7 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
} }
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 { func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) fn := util.GetHandle(ctx, pApp).(CollatingFunction)
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2)))) return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
} }
@@ -211,7 +249,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
} }
// We need to create the aggregate. // We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
if pAgg != 0 { if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn) handle := util.AddHandle(db.ctx, fn)
util.Write32(db.mod, pAgg, handle) util.Write32(db.mod, pAgg, handle)
@@ -232,6 +270,7 @@ func callbackArgs(db *Conn, arg []Value, pArg ptr_t) {
var funcArgsPool sync.Pool var funcArgsPool sync.Pool
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
clear(p[:])
funcArgsPool.Put(p) funcArgsPool.Put(p)
} }
@@ -242,3 +281,38 @@ func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
return p.(*[_MAX_FUNCTION_ARG]Value) return p.(*[_MAX_FUNCTION_ARG]Value)
} }
} }
type aggregateFunc struct {
ctx Context
arg []Value
next func() (struct{}, bool)
stop func()
}
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
a.ctx = ctx
a.arg = arg
if _, more := a.next(); !more {
a.stop()
}
}
func (a *aggregateFunc) Value(ctx Context) {
a.ctx = ctx
a.stop()
}
func (a *aggregateFunc) Close() error {
a.stop()
return nil
}
type windowFunc struct {
AggregateFunction
name string
}
func (w windowFunc) Inverse(ctx Context, arg ...Value) {
// Implementing inverse allows certain queries that don't really need it to succeed.
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
}

57
func_seq_test.go Normal file
View File

@@ -0,0 +1,57 @@
package sqlite3_test
import (
"fmt"
"iter"
"log"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateAggregateFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE test (col)`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`)
if err != nil {
log.Fatal(err)
}
err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS,
func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) {
count := 0
total := 0.0
for arg := range seq {
total += arg[0].Float()
count++
}
ctx.ResultFloat(total / float64(count))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnFloat(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 2
}