diff --git a/context.go b/context.go index 45e1b32..39066ef 100644 --- a/context.go +++ b/context.go @@ -14,6 +14,7 @@ import ( type Context struct { *module handle uint32 + final bool } // ResultBool sets the result of the function to a bool. @@ -154,4 +155,27 @@ func (c *Context) ResultError(err error) { c.call(c.api.resultErrorCode, uint64(c.handle), uint64(xcode)) } + +} + +func AggregateContext[T any](ctx Context) *T { + var size uint64 + if !ctx.final { + size = ptrlen + } + pAgg := uint32(ctx.call(ctx.api.aggregateData, uint64(ctx.handle), size)) + if pAgg == 0 { + return nil + } + pData := util.ReadUint32(ctx.mod, pAgg) + if data := util.GetHandle(ctx.ctx, pData); data != nil { + return data.(*T) + } + if ctx.final { + return nil + } + data := new(T) + pData = util.AddHandle(ctx.ctx, data) + util.WriteUint32(ctx.mod, pAgg, pData) + return data } diff --git a/embed/exports.txt b/embed/exports.txt index 70e5e3e..45418fa 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -50,8 +50,8 @@ sqlite3_get_autocommit sqlite3_anycollseq_init sqlite3_create_go_collation sqlite3_create_go_function -sqlite3_create_go_window_function sqlite3_create_go_aggregate_function +sqlite3_create_go_window_function sqlite3_aggregate_context sqlite3_user_data sqlite3_value_type diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 3b8be54..4dd5c40 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/func.go b/func.go index 00b1369..8f7fc51 100644 --- a/func.go +++ b/func.go @@ -35,48 +35,124 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func( return c.error(r) } +// CreateWindowFunction defines a new aggregate or window function. +// +// https://www.sqlite.org/c3ref/create_function.html +func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateFunction) error { + call := c.api.createAggregate + namePtr := c.arena.string(name) + funcPtr := util.AddHandle(c.ctx, fn) + if _, ok := fn.(WindowFunction); ok { + call = c.api.createWindow + } + r := c.call(call, + uint64(c.handle), uint64(namePtr), uint64(nArg), + uint64(flag), uint64(funcPtr)) + return c.error(r) +} + +// ScalarFunction is the interface a scalar function should implement. +// +// https://www.sqlite.org/appfunc.html +type ScalarFunction interface { + Func(ctx Context, arg ...Value) +} + +// AggregateFunction is the interface an aggregate function should implement. +// +// https://www.sqlite.org/appfunc.html +type AggregateFunction interface { + Step(ctx Context, arg ...Value) + Final(ctx Context) +} + +// WindowFunction is the interface an aggregate window function should implement. +// +// https://www.sqlite.org/windowfunctions.html +type WindowFunction interface { + AggregateFunction + Value(ctx Context) + Inverse(ctx Context, arg ...Value) +} + func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { - util.ExportFuncVI(env, "go_destroy", cbDestroy) - util.ExportFuncIIIIII(env, "go_compare", cbCompare) - util.ExportFuncVIII(env, "go_func", cbFunc) - util.ExportFuncVIII(env, "go_step", cbStep) - util.ExportFuncVI(env, "go_final", cbFinal) - util.ExportFuncVI(env, "go_value", cbValue) - util.ExportFuncVIII(env, "go_inverse", cbInverse) + util.ExportFuncVI(env, "go_destroy", callbackDestroy) + util.ExportFuncIIIIII(env, "go_compare", callbackCompare) + util.ExportFuncVIII(env, "go_func", callbackFunc) + util.ExportFuncVIII(env, "go_step", callbackStep) + util.ExportFuncVI(env, "go_final", callbackFinal) + util.ExportFuncVI(env, "go_value", callbackValue) + util.ExportFuncVIII(env, "go_inverse", callbackInverse) return env } -func cbDestroy(ctx context.Context, mod api.Module, pApp uint32) { +func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) { util.DelHandle(ctx, pApp) } -func cbCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { +func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) } -func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { +func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { module := ctx.Value(moduleKey{}).(*module) pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value)) + fn(Context{ + module: module, + handle: pCtx, + }, callbackArgs(module, nArg, pArg)...) +} + +func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { + module := ctx.Value(moduleKey{}).(*module) + pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + fn := util.GetHandle(ctx, pApp).(AggregateFunction) + fn.Step(Context{ + module: module, + handle: pCtx, + }, callbackArgs(module, nArg, pArg)...) +} + +func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) { + module := ctx.Value(moduleKey{}).(*module) + pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + fn := util.GetHandle(ctx, pApp).(AggregateFunction) + fn.Final(Context{ + module: module, + handle: pCtx, + final: true, + }) +} + +func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) { + module := ctx.Value(moduleKey{}).(*module) + pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + fn := util.GetHandle(ctx, pApp).(WindowFunction) + fn.Value(Context{ + module: module, + handle: pCtx, + }) +} + +func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { + module := ctx.Value(moduleKey{}).(*module) + pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + fn := util.GetHandle(ctx, pApp).(WindowFunction) + fn.Inverse(Context{ + module: module, + handle: pCtx, + }, callbackArgs(module, nArg, pArg)...) +} + +func callbackArgs(module *module, nArg, pArg uint32) []Value { args := make([]Value, nArg) for i := range args { args[i] = Value{ - handle: util.ReadUint32(mod, pArg+ptrlen*uint32(i)), module: module, + handle: util.ReadUint32(module.mod, pArg+ptrlen*uint32(i)), } } - context := Context{ - handle: pCtx, - module: module, - } - fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value)) - fn(context, args...) + return args } - -func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {} - -func cbFinal(ctx context.Context, mod api.Module, pCtx uint32) {} - -func cbValue(ctx context.Context, mod api.Module, pCtx uint32) {} - -func cbInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {} diff --git a/func_test.go b/func_test.go index 16b35d5..d6cb53f 100644 --- a/func_test.go +++ b/func_test.go @@ -13,7 +13,7 @@ import ( ) func ExampleConn_CreateCollation() { - db, err := sqlite3.Open(memory) + db, err := sqlite3.Open(":memory:") if err != nil { log.Fatal(err) } @@ -65,7 +65,7 @@ func ExampleConn_CreateCollation() { } func ExampleConn_CreateFunction() { - db, err := sqlite3.Open(memory) + db, err := sqlite3.Open(":memory:") if err != nil { log.Fatal(err) } diff --git a/func_win_test.go b/func_win_test.go new file mode 100644 index 0000000..791ad6b --- /dev/null +++ b/func_win_test.go @@ -0,0 +1,79 @@ +package sqlite3_test + +import ( + "fmt" + "log" + "unicode" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func ExampleConn_CreateWindowFunction() { + db, err := sqlite3.Open(":memory:") + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`) + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`) + if err != nil { + log.Fatal(err) + } + + err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, countASCII{}) + if err != nil { + log.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT count_ascii(word) FROM words`) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + for stmt.Step() { + fmt.Println(stmt.ColumnInt(0)) + } + if err := stmt.Err(); err != nil { + log.Fatal(err) + } + + err = stmt.Close() + if err != nil { + log.Fatal(err) + } + + err = db.Close() + if err != nil { + log.Fatal(err) + } + // Output: + // 2 +} + +type countASCII struct{} + +func (countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { + if arg[0].Type() != sqlite3.TEXT { + return + } + for _, c := range arg[0].RawText() { + if c > unicode.MaxASCII { + return + } + } + if count := sqlite3.AggregateContext[int](ctx); count != nil { + *count++ + } +} + +func (countASCII) Final(ctx sqlite3.Context) { + if count := sqlite3.AggregateContext[int](ctx); count != nil { + ctx.ResultInt(*count) + } +} diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..cf0b1d6 --- /dev/null +++ b/go.work.sum @@ -0,0 +1,4 @@ +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= diff --git a/module.go b/module.go index 4f253b6..e7576de 100644 --- a/module.go +++ b/module.go @@ -160,6 +160,9 @@ func newModule(mod api.Module) (m *module, err error) { autocommit: getFun("sqlite3_get_autocommit"), createCollation: getFun("sqlite3_create_go_collation"), createFunction: getFun("sqlite3_create_go_function"), + createAggregate: getFun("sqlite3_create_go_aggregate_function"), + createWindow: getFun("sqlite3_create_go_window_function"), + aggregateData: getFun("sqlite3_aggregate_context"), userData: getFun("sqlite3_user_data"), valueType: getFun("sqlite3_value_type"), valueInteger: getFun("sqlite3_value_int64"), @@ -374,6 +377,9 @@ type sqliteAPI struct { autocommit api.Function createCollation api.Function createFunction api.Function + createAggregate api.Function + createWindow api.Function + aggregateData api.Function userData api.Function valueType api.Function valueInteger api.Function diff --git a/sqlite3/func.c b/sqlite3/func.c index 1be22de..89a4712 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -21,15 +21,15 @@ int sqlite3_create_go_function(sqlite3 *db, const char *zName, int nArg, go_func, NULL, NULL, go_destroy); } -int sqlite3_create_go_window_function(sqlite3 *db, const char *zName, int nArg, - int flags, void *pApp) { +int sqlite3_create_go_aggregate_function(sqlite3 *db, const char *zName, + int nArg, int flags, void *pApp) { return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags, pApp, go_step, go_final, NULL, NULL, go_destroy); } -int sqlite3_create_go_aggregate_function(sqlite3 *db, const char *zName, - int nArg, int flags, void *pApp) { +int sqlite3_create_go_window_function(sqlite3 *db, const char *zName, int nArg, + int flags, void *pApp) { return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags, pApp, go_step, go_final, go_value, go_inverse, go_destroy);