diff --git a/const.go b/const.go index 49fda64..31f1972 100644 --- a/const.go +++ b/const.go @@ -13,6 +13,7 @@ const ( _MAX_LENGTH = 1e9 _MAX_SQL_LENGTH = 1e9 _MAX_ALLOCATION_SIZE = 0x7ffffeff + _MAX_FUNCTION_ARG = 100 ptrlen = 4 ) diff --git a/ext/hash/hash.go b/ext/hash/hash.go index 8f0656d..c649c8f 100644 --- a/ext/hash/hash.go +++ b/ext/hash/hash.go @@ -93,6 +93,5 @@ func hashFunc(ctx sqlite3.Context, arg sqlite3.Value, fn crypto.Hash) { h := fn.New() h.Write(data) - var res [64]byte - ctx.ResultBlob(h.Sum(res[:0])) + ctx.ResultBlob(h.Sum(nil)) } diff --git a/func.go b/func.go index d16e443..72253c4 100644 --- a/func.go +++ b/func.go @@ -2,6 +2,7 @@ package sqlite3 import ( "context" + "sync" "github.com/ncruces/go-sqlite3/internal/util" "github.com/tetratelabs/wazero/api" @@ -43,6 +44,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala } // ScalarFunction is the type of a scalar SQL function. +// Implementations must not retain arg. type ScalarFunction func(ctx Context, arg ...Value) // CreateWindowFunction defines a new aggregate or aggregate window SQL function. @@ -69,7 +71,8 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn // https://sqlite.org/appfunc.html type AggregateFunction interface { // Step is invoked to add a row to the current window. - // The function arguments, if any, corresponding to the row being added are passed to Step. + // The function arguments, if any, corresponding to the row being added, are passed to Step. + // Implementations must not retain arg. Step(ctx Context, arg ...Value) // Value is invoked to return the current (or final) value of the aggregate. @@ -84,6 +87,7 @@ type WindowFunction interface { // Inverse is invoked to remove the oldest presently aggregated result of Step from the current window. // The function arguments, if any, are those passed to Step for the row being removed. + // Implementations must not retain arg. Inverse(ctx Context, arg ...Value) } @@ -108,15 +112,21 @@ func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK } func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { + args := getFuncArgs() + defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) fn := userDataHandle(db, pCtx).(ScalarFunction) - fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) + callbackArgs(db, args[:nArg], pArg) + fn(Context{db, pCtx}, args[:nArg]...) } func stepCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { + args := getFuncArgs() + defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) fn := aggregateCtxHandle(db, pCtx, nil) - fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) + callbackArgs(db, args[:nArg], pArg) + fn.Step(Context{db, pCtx}, args[:nArg]...) } func finalCallback(ctx context.Context, mod api.Module, pCtx uint32) { @@ -136,9 +146,12 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx uint32) { } func inverseCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { + args := getFuncArgs() + defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) fn := aggregateCtxHandle(db, pCtx, nil).(WindowFunction) - fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) + callbackArgs(db, args[:nArg], pArg) + fn.Inverse(Context{db, pCtx}, args[:nArg]...) } func userDataHandle(db *Conn, pCtx uint32) any { @@ -174,13 +187,25 @@ func aggregateCtxHandle(db *Conn, pCtx uint32, close *uint32) AggregateFunction return fn } -func callbackArgs(db *Conn, nArg, pArg uint32) []Value { - args := make([]Value, nArg) - for i := range args { - args[i] = Value{ +func callbackArgs(db *Conn, arg []Value, pArg uint32) { + for i := range arg { + arg[i] = Value{ c: db, handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)), } } - return args +} + +var funcArgsPool sync.Pool + +func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { + funcArgsPool.Put(p) +} + +func getFuncArgs() *[_MAX_FUNCTION_ARG]Value { + if p := funcArgsPool.Get(); p == nil { + return new([_MAX_FUNCTION_ARG]Value) + } else { + return p.(*[_MAX_FUNCTION_ARG]Value) + } } diff --git a/vtab.go b/vtab.go index d4dc011..3d46450 100644 --- a/vtab.go +++ b/vtab.go @@ -414,12 +414,12 @@ const ( ) func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { - return func(ctx context.Context, mod api.Module, pMod, argc, argv, ppVTab, pzErr uint32) uint32 { - arg := make([]reflect.Value, 1+argc) + return func(ctx context.Context, mod api.Module, pMod, nArg, pArg, ppVTab, pzErr uint32) uint32 { + arg := make([]reflect.Value, 1+nArg) arg[0] = reflect.ValueOf(ctx.Value(connKey{})) - for i := uint32(0); i < argc; i++ { - ptr := util.ReadUint32(mod, argv+i*ptrlen) + for i := uint32(0); i < nArg; i++ { + ptr := util.ReadUint32(mod, pArg+i*ptrlen) arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH)) } @@ -461,11 +461,12 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, argc, argv, pRowID uint32) uint32 { +func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, nArg, pArg, pRowID uint32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) db := ctx.Value(connKey{}).(*Conn) - args := callbackArgs(db, argc, argv) + args := make([]Value, nArg) + callbackArgs(db, args, pArg) rowID, err := vtab.Update(args...) if err == nil { util.WriteUint64(mod, pRowID, uint64(rowID)) @@ -563,10 +564,11 @@ func cursorCloseCallback(ctx context.Context, mod api.Module, pCur uint32) uint3 return vtabError(ctx, mod, 0, _VTAB_ERROR, err) } -func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idxStr, argc, argv uint32) uint32 { +func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idxStr, nArg, pArg uint32) uint32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) - args := callbackArgs(db, argc, argv) + args := make([]Value, nArg) + callbackArgs(db, args, pArg) var idxName string if idxStr != 0 { idxName = util.ReadString(mod, idxStr, _MAX_LENGTH)