mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Seq aggregate functions (#229)
This commit is contained in:
@@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
|||||||
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
|
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
func readfile(fsys fs.FS) sqlite3.ScalarFunction {
|
||||||
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||||
var err error
|
var err error
|
||||||
var data []byte
|
var data []byte
|
||||||
|
|||||||
@@ -63,12 +63,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
|
|||||||
|
|
||||||
type cursor struct {
|
type cursor struct {
|
||||||
fsdir
|
fsdir
|
||||||
base string
|
base string
|
||||||
resume func() (entry, bool)
|
next func() (entry, bool)
|
||||||
cancel func()
|
stop func()
|
||||||
curr entry
|
curr entry
|
||||||
eof bool
|
eof bool
|
||||||
rowID int64
|
rowID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type entry struct {
|
type entry struct {
|
||||||
@@ -78,8 +78,8 @@ type entry struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *cursor) Close() error {
|
func (c *cursor) Close() error {
|
||||||
if c.cancel != nil {
|
if c.stop != nil {
|
||||||
c.cancel()
|
c.stop()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -102,7 +102,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
|||||||
c.base = base
|
c.base = base
|
||||||
}
|
}
|
||||||
|
|
||||||
c.resume, c.cancel = iter.Pull(func(yield func(entry) bool) {
|
c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
|
||||||
walkDir := func(path string, d fs.DirEntry, err error) error {
|
walkDir := func(path string, d fs.DirEntry, err error) error {
|
||||||
if yield(entry{d, err, path}) {
|
if yield(entry{d, err, path}) {
|
||||||
return nil
|
return nil
|
||||||
@@ -121,7 +121,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *cursor) Next() error {
|
func (c *cursor) Next() error {
|
||||||
curr, ok := c.resume()
|
curr, ok := c.next()
|
||||||
c.curr = curr
|
c.curr = curr
|
||||||
c.eof = !ok
|
c.eof = !ok
|
||||||
c.rowID++
|
c.rowID++
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ const (
|
|||||||
some
|
some
|
||||||
)
|
)
|
||||||
|
|
||||||
func newBoolean(kind int) func() sqlite3.AggregateFunction {
|
func newBoolean(kind int) sqlite3.AggregateConstructor {
|
||||||
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
|
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ const (
|
|||||||
percentile_disc
|
percentile_disc
|
||||||
)
|
)
|
||||||
|
|
||||||
func newPercentile(kind int) func() sqlite3.AggregateFunction {
|
func newPercentile(kind int) sqlite3.AggregateConstructor {
|
||||||
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
|
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ func special(kind int, n int64) (null, zero bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
func newVariance(kind int) sqlite3.AggregateConstructor {
|
||||||
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCovariance(kind int) func() sqlite3.AggregateFunction {
|
func newCovariance(kind int) sqlite3.AggregateConstructor {
|
||||||
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,7 +254,7 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMoments(kind int) func() sqlite3.AggregateFunction {
|
func newMoments(kind int) sqlite3.AggregateConstructor {
|
||||||
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
|
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
110
func.go
110
func.go
@@ -3,6 +3,7 @@ package sqlite3
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
"iter"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/tetratelabs/wazero/api"
|
"github.com/tetratelabs/wazero/api"
|
||||||
@@ -45,7 +46,7 @@ func (c Conn) AnyCollationNeeded() error {
|
|||||||
// CreateCollation defines a new collating sequence.
|
// CreateCollation defines a new collating sequence.
|
||||||
//
|
//
|
||||||
// https://sqlite.org/c3ref/create_collation.html
|
// https://sqlite.org/c3ref/create_collation.html
|
||||||
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
|
||||||
var funcPtr ptr_t
|
var funcPtr ptr_t
|
||||||
defer c.arena.mark()()
|
defer c.arena.mark()()
|
||||||
namePtr := c.arena.string(name)
|
namePtr := c.arena.string(name)
|
||||||
@@ -57,6 +58,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
|||||||
return c.error(rc)
|
return c.error(rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collating function is the type of a collation callback.
|
||||||
|
// Implementations must not retain a or b.
|
||||||
|
type CollatingFunction func(a, b []byte) int
|
||||||
|
|
||||||
// CreateFunction defines a new scalar SQL function.
|
// CreateFunction defines a new scalar SQL function.
|
||||||
//
|
//
|
||||||
// https://sqlite.org/c3ref/create_function.html
|
// https://sqlite.org/c3ref/create_function.html
|
||||||
@@ -77,34 +82,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
|
|||||||
// Implementations must not retain arg.
|
// Implementations must not retain arg.
|
||||||
type ScalarFunction func(ctx Context, arg ...Value)
|
type ScalarFunction func(ctx Context, arg ...Value)
|
||||||
|
|
||||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
// CreateAggregateFunction defines a new aggregate SQL function.
|
||||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
|
||||||
// If fn returns an [io.Closer], it will be called to free resources.
|
|
||||||
//
|
//
|
||||||
// https://sqlite.org/c3ref/create_function.html
|
// https://sqlite.org/c3ref/create_function.html
|
||||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
|
||||||
var funcPtr ptr_t
|
var funcPtr ptr_t
|
||||||
defer c.arena.mark()()
|
defer c.arena.mark()()
|
||||||
namePtr := c.arena.string(name)
|
namePtr := c.arena.string(name)
|
||||||
call := "sqlite3_create_aggregate_function_go"
|
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
agg := fn()
|
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
|
||||||
if c, ok := agg.(io.Closer); ok {
|
var a aggregateFunc
|
||||||
if err := c.Close(); err != nil {
|
coro := func(yieldCoro func(struct{}) bool) {
|
||||||
return err
|
seq := func(yieldSeq func([]Value) bool) {
|
||||||
|
for yieldSeq(a.arg) {
|
||||||
|
if !yieldCoro(struct{}{}) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn(&a.ctx, seq)
|
||||||
}
|
}
|
||||||
}
|
a.next, a.stop = iter.Pull(coro)
|
||||||
if _, ok := agg.(WindowFunction); ok {
|
return &a
|
||||||
call = "sqlite3_create_window_function_go"
|
}))
|
||||||
}
|
|
||||||
funcPtr = util.AddHandle(c.ctx, fn)
|
|
||||||
}
|
}
|
||||||
rc := res_t(c.call(call,
|
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
|
||||||
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
|
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
|
||||||
stk_t(flag), stk_t(funcPtr)))
|
stk_t(flag), stk_t(funcPtr)))
|
||||||
return c.error(rc)
|
return c.error(rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AggregateSeqFunction is the type of an aggregate SQL function.
|
||||||
|
// Implementations must not retain the slices produced by seq.
|
||||||
|
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
|
||||||
|
|
||||||
|
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||||
|
// If fn returns a [WindowFunction], an aggregate window function is created.
|
||||||
|
// If fn returns an [io.Closer], it will be called to free resources.
|
||||||
|
//
|
||||||
|
// https://sqlite.org/c3ref/create_function.html
|
||||||
|
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
|
||||||
|
var funcPtr ptr_t
|
||||||
|
defer c.arena.mark()()
|
||||||
|
namePtr := c.arena.string(name)
|
||||||
|
if fn != nil {
|
||||||
|
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
|
||||||
|
agg := fn()
|
||||||
|
if win, ok := agg.(WindowFunction); ok {
|
||||||
|
return win
|
||||||
|
}
|
||||||
|
return windowFunc{agg, name}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
rc := res_t(c.call("sqlite3_create_window_function_go",
|
||||||
|
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
|
||||||
|
stk_t(flag), stk_t(funcPtr)))
|
||||||
|
return c.error(rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AggregateConstructor is a an [AggregateFunction] constructor.
|
||||||
|
type AggregateConstructor func() AggregateFunction
|
||||||
|
|
||||||
// AggregateFunction is the interface an aggregate function should implement.
|
// AggregateFunction is the interface an aggregate function should implement.
|
||||||
//
|
//
|
||||||
// https://sqlite.org/appfunc.html
|
// https://sqlite.org/appfunc.html
|
||||||
@@ -153,7 +191,7 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
|
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
|
||||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
|
||||||
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
|
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +249,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We need to create the aggregate.
|
// We need to create the aggregate.
|
||||||
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
|
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
|
||||||
if pAgg != 0 {
|
if pAgg != 0 {
|
||||||
handle := util.AddHandle(db.ctx, fn)
|
handle := util.AddHandle(db.ctx, fn)
|
||||||
util.Write32(db.mod, pAgg, handle)
|
util.Write32(db.mod, pAgg, handle)
|
||||||
@@ -232,6 +270,7 @@ func callbackArgs(db *Conn, arg []Value, pArg ptr_t) {
|
|||||||
var funcArgsPool sync.Pool
|
var funcArgsPool sync.Pool
|
||||||
|
|
||||||
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
|
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
|
||||||
|
clear(p[:])
|
||||||
funcArgsPool.Put(p)
|
funcArgsPool.Put(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,3 +281,38 @@ func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
|
|||||||
return p.(*[_MAX_FUNCTION_ARG]Value)
|
return p.(*[_MAX_FUNCTION_ARG]Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type aggregateFunc struct {
|
||||||
|
ctx Context
|
||||||
|
arg []Value
|
||||||
|
next func() (struct{}, bool)
|
||||||
|
stop func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
|
||||||
|
a.ctx = ctx
|
||||||
|
a.arg = arg
|
||||||
|
if _, more := a.next(); !more {
|
||||||
|
a.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aggregateFunc) Value(ctx Context) {
|
||||||
|
a.ctx = ctx
|
||||||
|
a.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aggregateFunc) Close() error {
|
||||||
|
a.stop()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type windowFunc struct {
|
||||||
|
AggregateFunction
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w windowFunc) Inverse(ctx Context, arg ...Value) {
|
||||||
|
// Implementing inverse allows certain queries that don't really need it to succeed.
|
||||||
|
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
|
||||||
|
}
|
||||||
|
|||||||
57
func_seq_test.go
Normal file
57
func_seq_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package sqlite3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/ncruces/go-sqlite3"
|
||||||
|
_ "github.com/ncruces/go-sqlite3/embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleConn_CreateAggregateFunction() {
|
||||||
|
db, err := sqlite3.Open(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
err = db.Exec(`CREATE TABLE test (col)`)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS,
|
||||||
|
func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) {
|
||||||
|
count := 0
|
||||||
|
total := 0.0
|
||||||
|
for arg := range seq {
|
||||||
|
total += arg[0].Float()
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
ctx.ResultFloat(total / float64(count))
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for stmt.Step() {
|
||||||
|
fmt.Println(stmt.ColumnFloat(0))
|
||||||
|
}
|
||||||
|
if err := stmt.Err(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
// Output:
|
||||||
|
// 2
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user