diff --git a/func.go b/func.go index 6b69368..09c1b7a 100644 --- a/func.go +++ b/func.go @@ -2,6 +2,7 @@ package sqlite3 import ( "context" + "io" "sync" "github.com/tetratelabs/wazero/api" @@ -85,12 +86,18 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) - if fn != nil { - funcPtr = util.AddHandle(c.ctx, fn) - } call := "sqlite3_create_aggregate_function_go" - if _, ok := fn().(WindowFunction); ok { - call = "sqlite3_create_window_function_go" + if fn != nil { + agg := fn() + if c, ok := agg.(io.Closer); ok { + if err := c.Close(); err != nil { + return err + } + } + if _, ok := agg.(WindowFunction); ok { + call = "sqlite3_create_window_function_go" + } + funcPtr = util.AddHandle(c.ctx, fn) } rc := res_t(c.call(call, stk_t(c.handle), stk_t(namePtr), stk_t(nArg), @@ -172,7 +179,13 @@ func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t) db := ctx.Value(connKey{}).(*Conn) fn, handle := callbackAggregate(db, pAgg, pApp) fn.Value(Context{db, pCtx}) - if err := util.DelHandle(ctx, handle); err != nil { + var err error + if handle != 0 { + err = util.DelHandle(ctx, handle) + } else if c, ok := fn.(io.Closer); ok { + err = c.Close() + } + if err != nil { Context{db, pCtx}.ResultError(err) return // notest }