From 9bf14becaf871f13e898450d861b2a3d6fd5db2b Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 29 Nov 2023 10:38:03 +0000 Subject: [PATCH] Reentrant arenas. --- backup.go | 2 +- blob.go | 18 +++++++++--------- conn.go | 9 ++++----- context.go | 4 ++-- func.go | 6 +++--- sqlite.go | 18 ++++++++++++------ stmt.go | 2 +- vtab.go | 15 +++++++++------ 8 files changed, 41 insertions(+), 33 deletions(-) diff --git a/backup.go b/backup.go index 70c289c..e38cccc 100644 --- a/backup.go +++ b/backup.go @@ -62,7 +62,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) { } func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) { - defer c.arena.reset() + defer c.arena.mark()() dstPtr := c.arena.string(dstName) srcPtr := c.arena.string(srcName) diff --git a/blob.go b/blob.go index 0fd65b0..e977cad 100644 --- a/blob.go +++ b/blob.go @@ -30,7 +30,7 @@ var _ io.ReadWriteSeeker = &Blob{} // https://sqlite.org/c3ref/blob_open.html func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { c.checkInterrupt() - defer c.arena.reset() + defer c.arena.mark()() blobPtr := c.arena.new(ptrlen) dbPtr := c.arena.string(db) tablePtr := c.arena.string(table) @@ -92,8 +92,8 @@ func (b *Blob) Read(p []byte) (n int, err error) { want = avail } - ptr := b.c.new(uint64(want)) - defer b.c.free(ptr) + defer b.c.arena.mark()() + ptr := b.c.arena.new(uint64(want)) r := b.c.call(b.c.api.blobRead, uint64(b.handle), uint64(ptr), uint64(want), uint64(b.offset)) @@ -124,8 +124,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { want = avail } - ptr := b.c.new(uint64(want)) - defer b.c.free(ptr) + defer b.c.arena.mark()() + ptr := b.c.arena.new(uint64(want)) for want > 0 { r := b.c.call(b.c.api.blobRead, uint64(b.handle), @@ -158,8 +158,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { // // https://sqlite.org/c3ref/blob_write.html func (b *Blob) Write(p []byte) (n int, err error) { - ptr := b.c.newBytes(p) - defer b.c.free(ptr) + defer b.c.arena.mark()() + ptr := b.c.arena.bytes(p) r := b.c.call(b.c.api.blobWrite, uint64(b.handle), uint64(ptr), uint64(len(p)), uint64(b.offset)) @@ -187,8 +187,8 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) { want = 1 } - ptr := b.c.new(uint64(want)) - defer b.c.free(ptr) + defer b.c.arena.mark()() + ptr := b.c.arena.new(uint64(want)) for { mem := util.View(b.c.mod, ptr, uint64(want)) diff --git a/conn.go b/conn.go index 12e9e94..f690d68 100644 --- a/conn.go +++ b/conn.go @@ -72,7 +72,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { } func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { - defer c.arena.reset() + defer c.arena.mark()() connPtr := c.arena.new(ptrlen) namePtr := c.arena.string(filename) @@ -96,7 +96,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { } } - c.arena.reset() pragmaPtr := c.arena.string(pragmas.String()) r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0) if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { @@ -151,11 +150,11 @@ func (c *Conn) Close() error { // https://sqlite.org/c3ref/exec.html func (c *Conn) Exec(sql string) error { c.checkInterrupt() - defer c.arena.reset() + defer c.arena.mark()() sqlPtr := c.arena.string(sql) r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) - return c.error(r) + return c.error(r, sql) } // Prepare calls [Conn.PrepareFlags] with no flags. @@ -177,7 +176,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str return nil, "", nil } - defer c.arena.reset() + defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) tailPtr := c.arena.new(ptrlen) sqlPtr := c.arena.string(sql) diff --git a/context.go b/context.go index 2264c67..3ee1d2b 100644 --- a/context.go +++ b/context.go @@ -208,10 +208,10 @@ func (ctx Context) ResultError(err error) { msg, code := errorCode(err, _OK) if msg != "" { - ptr := ctx.c.newString(msg) + defer ctx.c.arena.mark()() + ptr := ctx.c.arena.string(msg) ctx.c.call(ctx.c.api.resultError, uint64(ctx.handle), uint64(ptr), uint64(len(msg))) - ctx.c.free(ptr) } if code != _OK { ctx.c.call(ctx.c.api.resultErrorCode, diff --git a/func.go b/func.go index 9e4e60b..7322beb 100644 --- a/func.go +++ b/func.go @@ -21,7 +21,7 @@ func (c *Conn) AnyCollationNeeded() { // // https://sqlite.org/c3ref/create_collation.html func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { - defer c.arena.reset() + defer c.arena.mark()() namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) r := c.call(c.api.createCollation, @@ -33,7 +33,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error { - defer c.arena.reset() + defer c.arena.mark()() namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) r := c.call(c.api.createFunction, @@ -48,7 +48,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func( // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { - defer c.arena.reset() + defer c.arena.mark()() call := c.api.createAggregate namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) diff --git a/sqlite.go b/sqlite.go index 2c3dbe2..3f4c2c8 100644 --- a/sqlite.go +++ b/sqlite.go @@ -294,17 +294,23 @@ func (a *arena) free() { if a.sqlt == nil { return } - a.reset() + for _, ptr := range a.ptrs { + a.sqlt.free(ptr) + } a.sqlt.free(a.base) a.sqlt = nil } -func (a *arena) reset() { - for _, ptr := range a.ptrs { - a.sqlt.free(ptr) +func (a *arena) mark() (reset func()) { + ptrs := len(a.ptrs) + next := a.next + return func() { + for _, ptr := range a.ptrs[ptrs:] { + a.sqlt.free(ptr) + } + a.ptrs = a.ptrs[:ptrs] + a.next = next } - a.ptrs = nil - a.next = 0 } func (a *arena) new(size uint64) uint32 { diff --git a/stmt.go b/stmt.go index 087497c..38cfd72 100644 --- a/stmt.go +++ b/stmt.go @@ -104,7 +104,7 @@ func (s *Stmt) BindCount() int { // // https://sqlite.org/c3ref/bind_parameter_index.html func (s *Stmt) BindIndex(name string) int { - defer s.c.arena.reset() + defer s.c.arena.mark()() namePtr := s.c.arena.string(name) r := s.c.call(s.c.api.bindIndex, uint64(s.handle), uint64(namePtr)) diff --git a/vtab.go b/vtab.go index 7274221..576b14c 100644 --- a/vtab.go +++ b/vtab.go @@ -53,7 +53,7 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor flags |= VTAB_SAVEPOINTER } - defer db.arena.reset() + defer db.arena.mark()() namePtr := db.arena.string(name) modulePtr := util.AddHandle(db.ctx, module[T]{create, connect}) r := db.call(db.api.createModule, uint64(db.handle), @@ -70,7 +70,7 @@ func implements[T any](typ reflect.Type) bool { // // https://sqlite.org/c3ref/declare_vtab.html func (c *Conn) DeclareVtab(sql string) error { - // The arena will be cleared by the prepare or exec method. + defer c.arena.mark()() sqlPtr := c.arena.string(sql) r := c.call(c.api.declareVTab, uint64(c.handle), uint64(sqlPtr)) return c.error(r) @@ -255,7 +255,7 @@ type IndexConstraintUsage struct { // // https://sqlite.org/c3ref/vtab_rhs_value.html func (idx *IndexInfo) RHSValue(column int) (*Value, error) { - // The arena will be cleared by the prepare or exec method. + defer idx.c.arena.mark()() valPtr := idx.c.arena.new(ptrlen) r := idx.c.call(idx.c.api.vtabRHSValue, uint64(idx.handle), uint64(column), uint64(valPtr)) @@ -318,7 +318,7 @@ func (idx *IndexInfo) save() { util.WriteUint32(mod, ptr+20, uint32(idx.IdxNum)) if idx.IdxStr != "" { util.WriteUint32(mod, ptr+24, idx.c.newString(idx.IdxStr)) - util.WriteUint32(mod, ptr+28, 1) + util.WriteUint32(mod, ptr+28, 1) // needToFreeIdxStr } if idx.OrderByConsumed { util.WriteUint32(mod, ptr+32, 1) @@ -567,11 +567,14 @@ func vtabError(ctx context.Context, mod api.Module, ptr, kind uint32, err error) if msg != "" && ptr != 0 { switch kind { case _VTAB_ERROR: - ptr = ptr + 8 + ptr = ptr + 8 // zErrMsg case _CURSOR_ERROR: - ptr = util.ReadUint32(mod, ptr) + 8 + ptr = util.ReadUint32(mod, ptr) + 8 // pVtab->zErrMsg } db := ctx.Value(connKey{}).(*Conn) + if ptr := util.ReadUint32(mod, ptr); ptr != 0 { + db.free(ptr) + } util.WriteUint32(mod, ptr, db.newString(msg)) } return code