From 0ace4646709fcbc6f0749bb37fecee71aedaf366 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 28 Jan 2023 12:47:39 +0000 Subject: [PATCH] Refactor. --- api.go | 3 +-- compile.go | 56 +++++++++++++++++++++++++++++++--------------------- conn.go | 37 +++++++++++----------------------- conn_test.go | 12 +++++------ mem.go | 10 ++++++++++ stmt.go | 4 ++-- vfs_test.go | 40 +++++++++++++++++++++++-------------- 7 files changed, 90 insertions(+), 72 deletions(-) diff --git a/api.go b/api.go index 3a265f9..e8997a9 100644 --- a/api.go +++ b/api.go @@ -18,8 +18,7 @@ func newConn(module api.Module) *Conn { destructor := memory{module}.readUint32(uint32(global.Get())) return &Conn{ - module: module, - memory: memory{module}, + mem: memory{module}, api: sqliteAPI{ malloc: getFun("malloc"), free: getFun("free"), diff --git a/compile.go b/compile.go index f91130f..cd2ddb1 100644 --- a/compile.go +++ b/compile.go @@ -3,10 +3,12 @@ package sqlite3 import ( "context" "os" + "strconv" "sync" "sync/atomic" "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" ) // Configure SQLite. @@ -15,33 +17,43 @@ var ( Path string // Path to load the binary from. ) -var ( - once sync.Once - wasm wazero.Runtime - module wazero.CompiledModule - counter atomic.Uint64 -) +var sqlite3 sqlite3Runtime -func compile() { - ctx := context.Background() +type sqlite3Runtime struct { + once sync.Once + runtime wazero.Runtime + compiled wazero.CompiledModule + instances atomic.Uint64 + ctx context.Context + err error +} - wasm = wazero.NewRuntime(ctx) - - if err := vfsInstantiate(ctx, wasm); err != nil { - panic(err) +func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) { + s.ctx = ctx + s.once.Do(s.compileModule) + if s.err != nil { + return nil, s.err } - if Binary == nil && Path != "" { - if bin, err := os.ReadFile(Path); err != nil { - panic(err) - } else { - Binary = bin + cfg := wazero.NewModuleConfig(). + WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10)) + return s.runtime.InstantiateModule(ctx, s.compiled, cfg) +} + +func (s *sqlite3Runtime) compileModule() { + s.runtime = wazero.NewRuntime(s.ctx) + s.err = vfsInstantiate(s.ctx, s.runtime) + if s.err != nil { + return + } + + bin := Binary + if bin == nil && Path != "" { + bin, s.err = os.ReadFile(Path) + if s.err != nil { + return } } - if m, err := wasm.CompileModule(ctx, Binary); err != nil { - panic(err) - } else { - module = m - } + s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin) } diff --git a/conn.go b/conn.go index f3daa9e..298b2f2 100644 --- a/conn.go +++ b/conn.go @@ -2,18 +2,13 @@ package sqlite3 import ( "context" - "strconv" - - "github.com/tetratelabs/wazero" - "github.com/tetratelabs/wazero/api" ) type Conn struct { ctx context.Context - handle uint32 - module api.Module - memory memory api sqliteAPI + mem memory + handle uint32 } func Open(filename string) (conn *Conn, err error) { @@ -21,12 +16,8 @@ func Open(filename string) (conn *Conn, err error) { } func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { - once.Do(compile) - ctx := context.Background() - cfg := wazero.NewModuleConfig(). - WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10)) - module, err := wasm.InstantiateModule(ctx, module, cfg) + module, err := sqlite3.instantiateModule(ctx) if err != nil { return nil, err } @@ -48,7 +39,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { return nil, err } - c.handle = c.memory.readUint32(connPtr) + c.handle = c.mem.readUint32(connPtr) if err := c.error(r[0]); err != nil { return nil, err } @@ -64,7 +55,7 @@ func (c *Conn) Close() error { if err := c.error(r[0]); err != nil { return err } - return c.module.Close(c.ctx) + return c.mem.mod.Close(c.ctx) } func (c *Conn) Exec(sql string) error { @@ -98,8 +89,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str } stmt = &Stmt{c: c} - stmt.handle = c.memory.readUint32(stmtPtr) - i := c.memory.readUint32(tailPtr) + stmt.handle = c.mem.readUint32(stmtPtr) + i := c.mem.readUint32(tailPtr) tail = sql[i-sqlPtr:] if err := c.error(r[0]); err != nil { @@ -130,12 +121,12 @@ func (c *Conn) error(rc uint64) error { // Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code. r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) if r != nil { - err.msg = c.getString(uint32(r[0]), 512) + err.msg = c.mem.readString(uint32(r[0]), 512) } r, _ = c.api.errstr.Call(c.ctx, rc) if r != nil { - err.str = c.getString(uint32(r[0]), 512) + err.str = c.mem.readString(uint32(r[0]), 512) } if err.msg == err.str { @@ -161,7 +152,7 @@ func (c *Conn) new(len uint32) uint32 { panic(err) } ptr := uint32(r[0]) - if ptr == 0 || ptr >= c.memory.size() { + if ptr == 0 || ptr >= c.mem.size() { panic(oomErr) } return ptr @@ -174,7 +165,7 @@ func (c *Conn) newBytes(b []byte) uint32 { siz := uint32(len(b)) ptr := c.new(siz) - buf, ok := c.memory.read(ptr, siz) + buf, ok := c.mem.read(ptr, siz) if !ok { c.api.free.Call(c.ctx, uint64(ptr)) panic(rangeErr) @@ -187,7 +178,7 @@ func (c *Conn) newBytes(b []byte) uint32 { func (c *Conn) newString(s string) uint32 { siz := uint32(len(s) + 1) ptr := c.new(siz) - buf, ok := c.memory.read(ptr, siz) + buf, ok := c.mem.read(ptr, siz) if !ok { c.api.free.Call(c.ctx, uint64(ptr)) panic(rangeErr) @@ -197,7 +188,3 @@ func (c *Conn) newString(s string) uint32 { copy(buf, s) return ptr } - -func (c *Conn) getString(ptr, maxlen uint32) string { - return c.memory.readString(ptr, maxlen) -} diff --git a/conn_test.go b/conn_test.go index ee8b105..76b71a7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -37,7 +37,7 @@ func TestConn_newBytes(t *testing.T) { } want := buf - if got := db.memory.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) { + if got := db.mem.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } } @@ -61,7 +61,7 @@ func TestConn_newString(t *testing.T) { } want := str + "\000" - if got := db.memory.mustRead(ptr, uint32(len(want))); string(got) != want { + if got := db.mem.mustRead(ptr, uint32(len(want))); string(got) != want { t.Errorf("got %q, want %q", got, want) } } @@ -85,22 +85,22 @@ func TestConn_getString(t *testing.T) { } want := "sqlite3" - if got := db.getString(ptr, math.MaxUint32); got != want { + if got := db.mem.readString(ptr, math.MaxUint32); got != want { t.Errorf("got %q, want %q", got, want) } - if got := db.getString(ptr, 0); got != "" { + if got := db.mem.readString(ptr, 0); got != "" { t.Errorf("got %q, want empty", got) } func() { defer func() { _ = recover() }() - db.getString(ptr, uint32(len(want)/2)) + db.mem.readString(ptr, uint32(len(want)/2)) t.Error("should have panicked") }() func() { defer func() { _ = recover() }() - db.getString(0, math.MaxUint32) + db.mem.readString(0, math.MaxUint32) t.Error("should have panicked") }() } diff --git a/mem.go b/mem.go index da8f675..43dad72 100644 --- a/mem.go +++ b/mem.go @@ -99,3 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string { return string(buf[:i]) } } + +func (m memory) writeString(ptr uint32, s string) { + siz := uint32(len(s) + 1) + buf, ok := m.read(ptr, siz) + if !ok { + panic(rangeErr) + } + buf[len(s)] = 0 + copy(buf, s) +} diff --git a/stmt.go b/stmt.go index 0b3d3c6..b6c8dcc 100644 --- a/stmt.go +++ b/stmt.go @@ -163,7 +163,7 @@ func (s *Stmt) ColumnText(col int) string { panic(err) } - mem := s.c.memory.mustRead(ptr, uint32(r[0])) + mem := s.c.mem.mustRead(ptr, uint32(r[0])) return string(mem) } @@ -190,6 +190,6 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { panic(err) } - mem := s.c.memory.mustRead(ptr, uint32(r[0])) + mem := s.c.mem.mustRead(ptr, uint32(r[0])) return append(buf[0:0], mem...) } diff --git a/vfs_test.go b/vfs_test.go index 33d1488..fc5ea33 100644 --- a/vfs_test.go +++ b/vfs_test.go @@ -7,12 +7,20 @@ import ( "io/fs" "math/rand" "os" + "path/filepath" "testing" "time" "github.com/ncruces/julianday" ) +func Test_vfsExit(t *testing.T) { + mem := newMemory(128) + defer func() { _ = recover() }() + vfsExit(context.TODO(), mem.mod, 1) + t.Error("should have panicked") +} + func Test_vfsLocaltime(t *testing.T) { mem := newMemory(128) @@ -96,43 +104,40 @@ func Test_vfsCurrentTime(t *testing.T) { } func Test_vfsCurrentTime64(t *testing.T) { - memory := make(mockMemory, 128) - module := &mockModule{&memory} + mem := newMemory(128) now := time.Now() time.Sleep(time.Millisecond) - rc := vfsCurrentTime64(context.TODO(), module, 0, 4) + rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4) if rc != 0 { t.Fatal("returned", rc) } day, nsec := julianday.Date(now) want := day*86_400_000 + nsec/1_000_000 - if got, _ := memory.ReadUint64Le(4); int64(got)-want > 100 { + if got := mem.readUint64(4); float32(got) != float32(want) { t.Errorf("got %v, want %v", got, want) } } func Test_vfsFullPathname(t *testing.T) { - memory := make(mockMemory, 128+_MAX_PATHNAME) - module := &mockModule{&memory} + mem := newMemory(128) + mem.writeString(4, ".") - memory.Write(4, []byte{'.', 0}) - - rc := vfsFullPathname(context.TODO(), module, 0, 4, 0, 8) + rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8) if rc != uint32(CANTOPEN_FULLPATH) { t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH) } - rc = vfsFullPathname(context.TODO(), module, 0, 4, _MAX_PATHNAME, 8) + rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8) if rc != _OK { t.Fatal("returned", rc) } - // want, _ := filepath.Abs(".") - // if got := getString(&memory, 8, _MAX_PATHNAME); got != want { - // t.Errorf("got %v, want %v", got, want) - // } + want, _ := filepath.Abs(".") + if got := mem.readString(8, _MAX_PATHNAME); got != want { + t.Errorf("got %v, want %v", got, want) + } } func Test_vfsDelete(t *testing.T) { @@ -156,7 +161,12 @@ func Test_vfsDelete(t *testing.T) { } if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) { - t.Error("did not delete the file") + t.Fatal("did not delete the file") + } + + rc = vfsDelete(context.TODO(), module, 0, 4, 1) + if rc != _OK { + t.Fatal("returned", rc) } }