diff --git a/conn.go b/conn.go index d777e16..9e28c3c 100644 --- a/conn.go +++ b/conn.go @@ -12,6 +12,7 @@ type Conn struct { ctx context.Context api sqliteAPI mem memory + arena arena handle uint32 } @@ -39,11 +40,11 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { if err != nil { return nil, err } + c.arena = c.newArena(1024) - namePtr := c.newString(filename) - connPtr := c.new(ptrlen) - defer c.free(namePtr) - defer c.free(connPtr) + defer c.arena.reset() + connPtr := c.arena.new(ptrlen) + namePtr := c.arena.string(filename) r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0) if err != nil { @@ -87,8 +88,8 @@ func (c *Conn) Close() error { // // https://www.sqlite.org/c3ref/exec.html func (c *Conn) Exec(sql string) error { - sqlPtr := c.newString(sql) - defer c.free(sqlPtr) + defer c.arena.reset() + sqlPtr := c.arena.string(sql) r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) if err != nil { @@ -109,12 +110,10 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { // // https://www.sqlite.org/c3ref/prepare.html func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) { - sqlPtr := c.newString(sql) - stmtPtr := c.new(ptrlen) - tailPtr := c.new(ptrlen) - defer c.free(sqlPtr) - defer c.free(stmtPtr) - defer c.free(tailPtr) + defer c.arena.reset() + sqlPtr := c.arena.string(sql) + stmtPtr := c.arena.new(ptrlen) + tailPtr := c.arena.new(ptrlen) r, err := c.api.prepare.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), @@ -186,13 +185,13 @@ func (c *Conn) free(ptr uint32) { } } -func (c *Conn) new(len uint32) uint32 { - r, err := c.api.malloc.Call(c.ctx, uint64(len)) +func (c *Conn) new(size uint32) uint32 { + r, err := c.api.malloc.Call(c.ctx, uint64(size)) if err != nil { panic(err) } ptr := uint32(r[0]) - if ptr == 0 && len != 0 { + if ptr == 0 && size != 0 { panic(oomErr) } return ptr @@ -202,19 +201,54 @@ func (c *Conn) newBytes(b []byte) uint32 { if b == nil { return 0 } - - siz := uint32(len(b)) - ptr := c.new(siz) - buf := c.mem.view(ptr, siz) - copy(buf, b) + ptr := c.new(uint32(len(b))) + c.mem.writeBytes(ptr, b) return ptr } func (c *Conn) newString(s string) uint32 { - siz := uint32(len(s) + 1) - ptr := c.new(siz) - buf := c.mem.view(ptr, siz) - buf[len(s)] = 0 - copy(buf, s) + ptr := c.new(uint32(len(s) + 1)) + c.mem.writeString(ptr, s) + return ptr +} + +func (c *Conn) newArena(size uint32) arena { + return arena{ + c: c, + size: size, + base: c.new(size), + } +} + +type arena struct { + c *Conn + base uint32 + next uint32 + size uint32 + ptrs []uint32 +} + +func (a *arena) reset() { + for _, ptr := range a.ptrs { + a.c.free(ptr) + } + a.ptrs = nil + a.next = 0 +} + +func (a *arena) new(size uint32) uint32 { + if a.next+size <= a.size { + ptr := a.base + a.next + a.next += size + return ptr + } + ptr := a.c.new(size) + a.ptrs = append(a.ptrs, ptr) + return ptr +} + +func (a *arena) string(s string) uint32 { + ptr := a.new(uint32(len(s) + 1)) + a.c.mem.writeString(ptr, s) return ptr } diff --git a/conn_test.go b/conn_test.go index 83ed24b..26fd106 100644 --- a/conn_test.go +++ b/conn_test.go @@ -112,6 +112,36 @@ func TestConn_new(t *testing.T) { t.Error("want panic") } +func TestConn_newArena(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + arena := db.newArena(16) + defer arena.reset() + + const title = "Lorem ipsum" + + ptr := arena.string(title) + if ptr == 0 { + t.Fatalf("got nullptr") + } + if got := db.mem.readString(ptr, math.MaxUint32); got != title { + t.Errorf("got %q, want %q", got, title) + } + + const body = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + ptr = arena.string(body) + if ptr == 0 { + t.Fatalf("got nullptr") + } + if got := db.mem.readString(ptr, math.MaxUint32); got != body { + t.Errorf("got %q, want %q", got, body) + } +} + func TestConn_newBytes(t *testing.T) { db, err := Open(":memory:") if err != nil { diff --git a/mem.go b/mem.go index 3a7f115..a29ae62 100644 --- a/mem.go +++ b/mem.go @@ -99,9 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string { } } +func (m memory) writeBytes(ptr uint32, b []byte) { + buf := m.view(ptr, uint32(len(b))) + copy(buf, b) +} + func (m memory) writeString(ptr uint32, s string) { - siz := uint32(len(s) + 1) - buf := m.view(ptr, siz) + buf := m.view(ptr, uint32(len(s)+1)) buf[len(s)] = 0 copy(buf, s) } diff --git a/vfs.go b/vfs.go index 78369a5..90d653b 100644 --- a/vfs.go +++ b/vfs.go @@ -113,11 +113,11 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull // Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did). // This might be buggy on Windows (the Windows VFS doesn't try). - siz := uint32(len(abs) + 1) - if siz > nFull { + size := uint32(len(abs) + 1) + if size > nFull { return uint32(CANTOPEN_FULLPATH) } - mem := memory{mod}.view(zFull, siz) + mem := memory{mod}.view(zFull, size) mem[len(abs)] = 0 copy(mem, abs)