From 9e7a0a875dab7fdf36157a3fd12621bac6e18085 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 10 Mar 2025 12:01:15 +0000 Subject: [PATCH] Improved arg reuse. --- const.go | 7 +++-- ext/bloom/bloom.go | 6 ++--- ext/statement/stmt.go | 4 +-- func.go | 59 ++++++++++++++++++++++--------------------- vtab.go | 22 +++++++++------- 5 files changed, 51 insertions(+), 47 deletions(-) diff --git a/const.go b/const.go index 82d8051..522f68b 100644 --- a/const.go +++ b/const.go @@ -11,10 +11,9 @@ const ( _ROW = 100 /* sqlite3_step() has another row ready */ _DONE = 101 /* sqlite3_step() has finished executing */ - _MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings. - _MAX_LENGTH = 1e9 - _MAX_SQL_LENGTH = 1e9 - _MAX_FUNCTION_ARG = 100 + _MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings. + _MAX_LENGTH = 1e9 + _MAX_SQL_LENGTH = 1e9 ptrlen = util.PtrLen intlen = util.IntLen diff --git a/ext/bloom/bloom.go b/ext/bloom/bloom.go index e065406..b71f90a 100644 --- a/ext/bloom/bloom.go +++ b/ext/bloom/bloom.go @@ -268,13 +268,13 @@ func (b *bloom) Open() (sqlite3.VTabCursor, error) { type cursor struct { *bloom - arg *sqlite3.Value + arg sqlite3.Value eof bool } func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { c.eof = false - c.arg = &arg[0] + c.arg = arg[0] blob := arg[0].RawBlob() f, err := c.db.OpenBlob(c.schema, c.storage, "data", 1, false) @@ -312,7 +312,7 @@ func (c *cursor) Column(ctx sqlite3.Context, n int) error { case 0: ctx.ResultBool(true) case 1: - ctx.ResultValue(*c.arg) + ctx.ResultValue(c.arg) } return nil } diff --git a/ext/statement/stmt.go b/ext/statement/stmt.go index 6e1a000..cb9614e 100644 --- a/ext/statement/stmt.go +++ b/ext/statement/stmt.go @@ -159,8 +159,6 @@ func (c *cursor) Close() error { } func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { - c.arg = arg - c.rowID = 0 err := errors.Join( c.stmt.Reset(), c.stmt.ClearBindings()) @@ -187,6 +185,8 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { return err } } + c.arg = append(c.arg[:0], arg...) + c.rowID = 0 return c.Next() } diff --git a/func.go b/func.go index e0dcb37..095e6c0 100644 --- a/func.go +++ b/func.go @@ -5,6 +5,7 @@ import ( "io" "iter" "sync" + "sync/atomic" "github.com/tetratelabs/wazero/api" @@ -196,21 +197,19 @@ func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int3 } func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) { - args := getFuncArgs() - defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) fn := util.GetHandle(db.ctx, pApp).(ScalarFunction) - callbackArgs(db, args[:nArg], pArg) - fn(Context{db, pCtx}, args[:nArg]...) + fn(Context{db, pCtx}, *args...) } func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) { - args := getFuncArgs() - defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) - callbackArgs(db, args[:nArg], pArg) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) fn, _ := callbackAggregate(db, pAgg, pApp) - fn.Step(Context{db, pCtx}, args[:nArg]...) + fn.Step(Context{db, pCtx}, *args...) } func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) { @@ -234,12 +233,11 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, } func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) { - args := getFuncArgs() - defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) - callbackArgs(db, args[:nArg], pArg) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) fn := util.GetHandle(db.ctx, pAgg).(WindowFunction) - fn.Inverse(Context{db, pCtx}, args[:nArg]...) + fn.Inverse(Context{db, pCtx}, *args...) } func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { @@ -258,28 +256,31 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { return fn, 0 } -func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { - for i := range arg { - arg[i] = Value{ +var ( + valueArgsPool sync.Pool + valueArgsLen atomic.Int32 +) + +func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value { + arg, ok := valueArgsPool.Get().(*[]Value) + if !ok || cap(*arg) < int(nArg) { + max := valueArgsLen.Or(nArg) | nArg + lst := make([]Value, max) + arg = &lst + } + lst := (*arg)[:nArg] + for i := range lst { + lst[i] = Value{ c: db, handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen), } } + *arg = lst + return arg } -var funcArgsPool sync.Pool - -func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { - clear(p[:]) - funcArgsPool.Put(p) -} - -func getFuncArgs() *[_MAX_FUNCTION_ARG]Value { - if p := funcArgsPool.Get(); p == nil { - return new([_MAX_FUNCTION_ARG]Value) - } else { - return p.(*[_MAX_FUNCTION_ARG]Value) - } +func returnArgs(p *[]Value) { + valueArgsPool.Put(p) } type aggregateFunc struct { @@ -291,7 +292,7 @@ type aggregateFunc struct { func (a *aggregateFunc) Step(ctx Context, arg ...Value) { a.ctx = ctx - a.arg = arg + a.arg = append(a.arg[:0], arg...) if _, more := a.next(); !more { a.stop() } diff --git a/vtab.go b/vtab.go index 884aaaa..f4282d9 100644 --- a/vtab.go +++ b/vtab.go @@ -162,6 +162,7 @@ type VTabDestroyer interface { } // A VTabUpdater allows a virtual table to be updated. +// Implementations must not retain arg. type VTabUpdater interface { VTab // https://sqlite.org/vtab.html#xupdate @@ -241,6 +242,7 @@ type VTabSavepointer interface { // to loop through the virtual table. // A VTabCursor may optionally implement // [io.Closer] to free resources. +// Implementations of Filter must not retain arg. // // https://sqlite.org/c3ref/vtab_cursor.html type VTabCursor interface { @@ -489,12 +491,12 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo } func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) res_t { - vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) - db := ctx.Value(connKey{}).(*Conn) - args := make([]Value, nArg) - callbackArgs(db, args, pArg) - rowID, err := vtab.Update(args...) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) + + vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) + rowID, err := vtab.Update(*args...) if err == nil { util.Write64(mod, pRowID, rowID) } @@ -593,15 +595,17 @@ func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t } func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) res_t { - cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) - args := make([]Value, nArg) - callbackArgs(db, args, pArg) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) + var idxName string if idxStr != 0 { idxName = util.ReadString(mod, idxStr, _MAX_LENGTH) } - err := cursor.Filter(int(idxNum), idxName, args...) + + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + err := cursor.Filter(int(idxNum), idxName, *args...) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) }