Refactor.

This commit is contained in:
Nuno Cruces
2023-01-28 12:47:39 +00:00
parent f4125bcd93
commit 0ace464670
7 changed files with 90 additions and 72 deletions

3
api.go
View File

@@ -18,8 +18,7 @@ func newConn(module api.Module) *Conn {
destructor := memory{module}.readUint32(uint32(global.Get())) destructor := memory{module}.readUint32(uint32(global.Get()))
return &Conn{ return &Conn{
module: module, mem: memory{module},
memory: memory{module},
api: sqliteAPI{ api: sqliteAPI{
malloc: getFun("malloc"), malloc: getFun("malloc"),
free: getFun("free"), free: getFun("free"),

View File

@@ -3,10 +3,12 @@ package sqlite3
import ( import (
"context" "context"
"os" "os"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
) )
// Configure SQLite. // Configure SQLite.
@@ -15,33 +17,43 @@ var (
Path string // Path to load the binary from. Path string // Path to load the binary from.
) )
var ( var sqlite3 sqlite3Runtime
once sync.Once
wasm wazero.Runtime
module wazero.CompiledModule
counter atomic.Uint64
)
func compile() { type sqlite3Runtime struct {
ctx := context.Background() once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
ctx context.Context
err error
}
wasm = wazero.NewRuntime(ctx) func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.ctx = ctx
if err := vfsInstantiate(ctx, wasm); err != nil { s.once.Do(s.compileModule)
panic(err) if s.err != nil {
return nil, s.err
} }
if Binary == nil && Path != "" { cfg := wazero.NewModuleConfig().
if bin, err := os.ReadFile(Path); err != nil { WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
panic(err) return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
} else { }
Binary = bin
func (s *sqlite3Runtime) compileModule() {
s.runtime = wazero.NewRuntime(s.ctx)
s.err = vfsInstantiate(s.ctx, s.runtime)
if s.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
return
} }
} }
if m, err := wasm.CompileModule(ctx, Binary); err != nil { s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin)
panic(err)
} else {
module = m
}
} }

37
conn.go
View File

@@ -2,18 +2,13 @@ package sqlite3
import ( import (
"context" "context"
"strconv"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
) )
type Conn struct { type Conn struct {
ctx context.Context ctx context.Context
handle uint32
module api.Module
memory memory
api sqliteAPI api sqliteAPI
mem memory
handle uint32
} }
func Open(filename string) (conn *Conn, err error) { func Open(filename string) (conn *Conn, err error) {
@@ -21,12 +16,8 @@ func Open(filename string) (conn *Conn, err error) {
} }
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
once.Do(compile)
ctx := context.Background() ctx := context.Background()
cfg := wazero.NewModuleConfig(). module, err := sqlite3.instantiateModule(ctx)
WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10))
module, err := wasm.InstantiateModule(ctx, module, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -48,7 +39,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
return nil, err return nil, err
} }
c.handle = c.memory.readUint32(connPtr) c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil { if err := c.error(r[0]); err != nil {
return nil, err return nil, err
} }
@@ -64,7 +55,7 @@ func (c *Conn) Close() error {
if err := c.error(r[0]); err != nil { if err := c.error(r[0]); err != nil {
return err return err
} }
return c.module.Close(c.ctx) return c.mem.mod.Close(c.ctx)
} }
func (c *Conn) Exec(sql string) error { func (c *Conn) Exec(sql string) error {
@@ -98,8 +89,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
} }
stmt = &Stmt{c: c} stmt = &Stmt{c: c}
stmt.handle = c.memory.readUint32(stmtPtr) stmt.handle = c.mem.readUint32(stmtPtr)
i := c.memory.readUint32(tailPtr) i := c.mem.readUint32(tailPtr)
tail = sql[i-sqlPtr:] tail = sql[i-sqlPtr:]
if err := c.error(r[0]); err != nil { if err := c.error(r[0]); err != nil {
@@ -130,12 +121,12 @@ func (c *Conn) error(rc uint64) error {
// Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code. // Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code.
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil { if r != nil {
err.msg = c.getString(uint32(r[0]), 512) err.msg = c.mem.readString(uint32(r[0]), 512)
} }
r, _ = c.api.errstr.Call(c.ctx, rc) r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil { if r != nil {
err.str = c.getString(uint32(r[0]), 512) err.str = c.mem.readString(uint32(r[0]), 512)
} }
if err.msg == err.str { if err.msg == err.str {
@@ -161,7 +152,7 @@ func (c *Conn) new(len uint32) uint32 {
panic(err) panic(err)
} }
ptr := uint32(r[0]) ptr := uint32(r[0])
if ptr == 0 || ptr >= c.memory.size() { if ptr == 0 || ptr >= c.mem.size() {
panic(oomErr) panic(oomErr)
} }
return ptr return ptr
@@ -174,7 +165,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
siz := uint32(len(b)) siz := uint32(len(b))
ptr := c.new(siz) ptr := c.new(siz)
buf, ok := c.memory.read(ptr, siz) buf, ok := c.mem.read(ptr, siz)
if !ok { if !ok {
c.api.free.Call(c.ctx, uint64(ptr)) c.api.free.Call(c.ctx, uint64(ptr))
panic(rangeErr) panic(rangeErr)
@@ -187,7 +178,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
func (c *Conn) newString(s string) uint32 { func (c *Conn) newString(s string) uint32 {
siz := uint32(len(s) + 1) siz := uint32(len(s) + 1)
ptr := c.new(siz) ptr := c.new(siz)
buf, ok := c.memory.read(ptr, siz) buf, ok := c.mem.read(ptr, siz)
if !ok { if !ok {
c.api.free.Call(c.ctx, uint64(ptr)) c.api.free.Call(c.ctx, uint64(ptr))
panic(rangeErr) panic(rangeErr)
@@ -197,7 +188,3 @@ func (c *Conn) newString(s string) uint32 {
copy(buf, s) copy(buf, s)
return ptr return ptr
} }
func (c *Conn) getString(ptr, maxlen uint32) string {
return c.memory.readString(ptr, maxlen)
}

View File

@@ -37,7 +37,7 @@ func TestConn_newBytes(t *testing.T) {
} }
want := buf want := buf
if got := db.memory.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) { if got := db.mem.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want) t.Errorf("got %q, want %q", got, want)
} }
} }
@@ -61,7 +61,7 @@ func TestConn_newString(t *testing.T) {
} }
want := str + "\000" want := str + "\000"
if got := db.memory.mustRead(ptr, uint32(len(want))); string(got) != want { if got := db.mem.mustRead(ptr, uint32(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want) t.Errorf("got %q, want %q", got, want)
} }
} }
@@ -85,22 +85,22 @@ func TestConn_getString(t *testing.T) {
} }
want := "sqlite3" want := "sqlite3"
if got := db.getString(ptr, math.MaxUint32); got != want { if got := db.mem.readString(ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want) t.Errorf("got %q, want %q", got, want)
} }
if got := db.getString(ptr, 0); got != "" { if got := db.mem.readString(ptr, 0); got != "" {
t.Errorf("got %q, want empty", got) t.Errorf("got %q, want empty", got)
} }
func() { func() {
defer func() { _ = recover() }() defer func() { _ = recover() }()
db.getString(ptr, uint32(len(want)/2)) db.mem.readString(ptr, uint32(len(want)/2))
t.Error("should have panicked") t.Error("should have panicked")
}() }()
func() { func() {
defer func() { _ = recover() }() defer func() { _ = recover() }()
db.getString(0, math.MaxUint32) db.mem.readString(0, math.MaxUint32)
t.Error("should have panicked") t.Error("should have panicked")
}() }()
} }

10
mem.go
View File

@@ -99,3 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string {
return string(buf[:i]) return string(buf[:i])
} }
} }
func (m memory) writeString(ptr uint32, s string) {
siz := uint32(len(s) + 1)
buf, ok := m.read(ptr, siz)
if !ok {
panic(rangeErr)
}
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -163,7 +163,7 @@ func (s *Stmt) ColumnText(col int) string {
panic(err) panic(err)
} }
mem := s.c.memory.mustRead(ptr, uint32(r[0])) mem := s.c.mem.mustRead(ptr, uint32(r[0]))
return string(mem) return string(mem)
} }
@@ -190,6 +190,6 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
panic(err) panic(err)
} }
mem := s.c.memory.mustRead(ptr, uint32(r[0])) mem := s.c.mem.mustRead(ptr, uint32(r[0]))
return append(buf[0:0], mem...) return append(buf[0:0], mem...)
} }

View File

@@ -7,12 +7,20 @@ import (
"io/fs" "io/fs"
"math/rand" "math/rand"
"os" "os"
"path/filepath"
"testing" "testing"
"time" "time"
"github.com/ncruces/julianday" "github.com/ncruces/julianday"
) )
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1)
t.Error("should have panicked")
}
func Test_vfsLocaltime(t *testing.T) { func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128) mem := newMemory(128)
@@ -96,43 +104,40 @@ func Test_vfsCurrentTime(t *testing.T) {
} }
func Test_vfsCurrentTime64(t *testing.T) { func Test_vfsCurrentTime64(t *testing.T) {
memory := make(mockMemory, 128) mem := newMemory(128)
module := &mockModule{&memory}
now := time.Now() now := time.Now()
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(context.TODO(), module, 0, 4) rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
if rc != 0 { if rc != 0 {
t.Fatal("returned", rc) t.Fatal("returned", rc)
} }
day, nsec := julianday.Date(now) day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000 want := day*86_400_000 + nsec/1_000_000
if got, _ := memory.ReadUint64Le(4); int64(got)-want > 100 { if got := mem.readUint64(4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
} }
} }
func Test_vfsFullPathname(t *testing.T) { func Test_vfsFullPathname(t *testing.T) {
memory := make(mockMemory, 128+_MAX_PATHNAME) mem := newMemory(128)
module := &mockModule{&memory} mem.writeString(4, ".")
memory.Write(4, []byte{'.', 0}) rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
rc := vfsFullPathname(context.TODO(), module, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) { if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH) t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
} }
rc = vfsFullPathname(context.TODO(), module, 0, 4, _MAX_PATHNAME, 8) rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK { if rc != _OK {
t.Fatal("returned", rc) t.Fatal("returned", rc)
} }
// want, _ := filepath.Abs(".") want, _ := filepath.Abs(".")
// if got := getString(&memory, 8, _MAX_PATHNAME); got != want { if got := mem.readString(8, _MAX_PATHNAME); got != want {
// t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
// } }
} }
func Test_vfsDelete(t *testing.T) { func Test_vfsDelete(t *testing.T) {
@@ -156,7 +161,12 @@ func Test_vfsDelete(t *testing.T) {
} }
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) { if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
t.Error("did not delete the file") t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), module, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
} }
} }