From 1cc7ecfe8d64394ee1b8806649e7983137c00260 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 3 Jul 2023 15:45:16 +0100 Subject: [PATCH] Custom aggregate functions. --- context.go | 24 ------------------ func.go | 65 ++++++++++++++++++++++++++++++++---------------- func_win_test.go | 57 +++++++++++++++++++++++++++++------------- module.go | 4 +-- 4 files changed, 85 insertions(+), 65 deletions(-) diff --git a/context.go b/context.go index 39066ef..45e1b32 100644 --- a/context.go +++ b/context.go @@ -14,7 +14,6 @@ import ( type Context struct { *module handle uint32 - final bool } // ResultBool sets the result of the function to a bool. @@ -155,27 +154,4 @@ func (c *Context) ResultError(err error) { c.call(c.api.resultErrorCode, uint64(c.handle), uint64(xcode)) } - -} - -func AggregateContext[T any](ctx Context) *T { - var size uint64 - if !ctx.final { - size = ptrlen - } - pAgg := uint32(ctx.call(ctx.api.aggregateData, uint64(ctx.handle), size)) - if pAgg == 0 { - return nil - } - pData := util.ReadUint32(ctx.mod, pAgg) - if data := util.GetHandle(ctx.ctx, pData); data != nil { - return data.(*T) - } - if ctx.final { - return nil - } - data := new(T) - pData = util.AddHandle(ctx.ctx, data) - util.WriteUint32(ctx.mod, pAgg, pData) - return data } diff --git a/func.go b/func.go index 8f7fc51..3bf73ae 100644 --- a/func.go +++ b/func.go @@ -23,7 +23,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { return nil } -// CreateFunction defines a new scalar function. +// CreateFunction defines a new scalar SQL function. // // https://www.sqlite.org/c3ref/create_function.html func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error { @@ -35,14 +35,15 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func( return c.error(r) } -// CreateWindowFunction defines a new aggregate or window function. +// CreateWindowFunction defines a new aggregate or aggregate window SQL function. +// If fn returns a [WindowFunction], then an aggregate window function is created. // // https://www.sqlite.org/c3ref/create_function.html -func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateFunction) error { +func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { call := c.api.createAggregate namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) - if _, ok := fn.(WindowFunction); ok { + if _, ok := fn().(WindowFunction); ok { call = c.api.createWindow } r := c.call(call, @@ -51,13 +52,6 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn return c.error(r) } -// ScalarFunction is the interface a scalar function should implement. -// -// https://www.sqlite.org/appfunc.html -type ScalarFunction interface { - Func(ctx Context, arg ...Value) -} - // AggregateFunction is the interface an aggregate function should implement. // // https://www.sqlite.org/appfunc.html @@ -97,8 +91,7 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { module := ctx.Value(moduleKey{}).(*module) - pApp := uint32(module.call(module.api.userData, uint64(pCtx))) - fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value)) + fn := callbackHandle(module, pCtx).(func(ctx Context, arg ...Value)) fn(Context{ module: module, handle: pCtx, @@ -107,8 +100,7 @@ func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { module := ctx.Value(moduleKey{}).(*module) - pApp := uint32(module.call(module.api.userData, uint64(pCtx))) - fn := util.GetHandle(ctx, pApp).(AggregateFunction) + fn := callbackAggregate(module, pCtx, nil).(AggregateFunction) fn.Step(Context{ module: module, handle: pCtx, @@ -116,20 +108,19 @@ func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) } func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) { + var handle uint32 module := ctx.Value(moduleKey{}).(*module) - pApp := uint32(module.call(module.api.userData, uint64(pCtx))) - fn := util.GetHandle(ctx, pApp).(AggregateFunction) + fn := callbackAggregate(module, pCtx, &handle).(AggregateFunction) fn.Final(Context{ module: module, handle: pCtx, - final: true, }) + util.DelHandle(ctx, handle) } func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) { module := ctx.Value(moduleKey{}).(*module) - pApp := uint32(module.call(module.api.userData, uint64(pCtx))) - fn := util.GetHandle(ctx, pApp).(WindowFunction) + fn := callbackAggregate(module, pCtx, nil).(WindowFunction) fn.Value(Context{ module: module, handle: pCtx, @@ -138,14 +129,44 @@ func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) { func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { module := ctx.Value(moduleKey{}).(*module) - pApp := uint32(module.call(module.api.userData, uint64(pCtx))) - fn := util.GetHandle(ctx, pApp).(WindowFunction) + fn := callbackAggregate(module, pCtx, nil).(WindowFunction) fn.Inverse(Context{ module: module, handle: pCtx, }, callbackArgs(module, nArg, pArg)...) } +func callbackHandle(module *module, pCtx uint32) any { + pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + return util.GetHandle(module.ctx, pApp) +} + +func callbackAggregate(module *module, pCtx uint32, delete *uint32) any { + var size uint64 + if delete == nil { + size = ptrlen + } + ptr := uint32(module.call(module.api.aggregateCtx, uint64(pCtx), size)) + + if ptr != 0 { + if handle := util.ReadUint32(module.mod, ptr); handle != 0 { + fn := util.GetHandle(module.ctx, handle) + if delete != nil { + *delete = handle + } + if fn != nil { + return fn + } + } + } + + fn := callbackHandle(module, pCtx).(func() AggregateFunction)() + if ptr != 0 { + util.WriteUint32(module.mod, ptr, util.AddHandle(module.ctx, fn)) + } + return fn +} + func callbackArgs(module *module, nArg, pArg uint32) []Value { args := make([]Value, nArg) for i := range args { diff --git a/func_win_test.go b/func_win_test.go index 791ad6b..13ba5d2 100644 --- a/func_win_test.go +++ b/func_win_test.go @@ -25,12 +25,12 @@ func ExampleConn_CreateWindowFunction() { log.Fatal(err) } - err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, countASCII{}) + err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter) if err != nil { log.Fatal(err) } - stmt, _, err := db.Prepare(`SELECT count_ascii(word) FROM words`) + stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`) if err != nil { log.Fatal(err) } @@ -53,27 +53,50 @@ func ExampleConn_CreateWindowFunction() { log.Fatal(err) } // Output: + // 1 // 2 + // 2 + // 1 + // 0 + // 0 } -type countASCII struct{} +type countASCII struct { + result int +} -func (countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { - if arg[0].Type() != sqlite3.TEXT { - return +func newASCIICounter() sqlite3.AggregateFunction { + return &countASCII{} +} + +func (f *countASCII) Final(ctx sqlite3.Context) { + f.Value(ctx) +} + +func (f *countASCII) Value(ctx sqlite3.Context) { + ctx.ResultInt(f.result) +} + +func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { + if f.isASCII(arg[0]) { + f.result++ } - for _, c := range arg[0].RawText() { +} + +func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { + if f.isASCII(arg[0]) { + f.result-- + } +} + +func (f *countASCII) isASCII(arg sqlite3.Value) bool { + if arg.Type() != sqlite3.TEXT { + return false + } + for _, c := range arg.RawText() { if c > unicode.MaxASCII { - return + return false } } - if count := sqlite3.AggregateContext[int](ctx); count != nil { - *count++ - } -} - -func (countASCII) Final(ctx sqlite3.Context) { - if count := sqlite3.AggregateContext[int](ctx); count != nil { - ctx.ResultInt(*count) - } + return true } diff --git a/module.go b/module.go index e7576de..944b107 100644 --- a/module.go +++ b/module.go @@ -162,7 +162,7 @@ func newModule(mod api.Module) (m *module, err error) { createFunction: getFun("sqlite3_create_go_function"), createAggregate: getFun("sqlite3_create_go_aggregate_function"), createWindow: getFun("sqlite3_create_go_window_function"), - aggregateData: getFun("sqlite3_aggregate_context"), + aggregateCtx: getFun("sqlite3_aggregate_context"), userData: getFun("sqlite3_user_data"), valueType: getFun("sqlite3_value_type"), valueInteger: getFun("sqlite3_value_int64"), @@ -379,7 +379,7 @@ type sqliteAPI struct { createFunction api.Function createAggregate api.Function createWindow api.Function - aggregateData api.Function + aggregateCtx api.Function userData api.Function valueType api.Function valueInteger api.Function