Refactor.

This commit is contained in:
Nuno Cruces
2023-01-28 12:47:39 +00:00
parent f4125bcd93
commit 0ace464670
7 changed files with 90 additions and 72 deletions

3
api.go
View File

@@ -18,8 +18,7 @@ func newConn(module api.Module) *Conn {
destructor := memory{module}.readUint32(uint32(global.Get()))
return &Conn{
module: module,
memory: memory{module},
mem: memory{module},
api: sqliteAPI{
malloc: getFun("malloc"),
free: getFun("free"),

View File

@@ -3,10 +3,12 @@ package sqlite3
import (
"context"
"os"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite.
@@ -15,33 +17,43 @@ var (
Path string // Path to load the binary from.
)
var (
once sync.Once
wasm wazero.Runtime
module wazero.CompiledModule
counter atomic.Uint64
)
var sqlite3 sqlite3Runtime
func compile() {
ctx := context.Background()
type sqlite3Runtime struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
ctx context.Context
err error
}
wasm = wazero.NewRuntime(ctx)
if err := vfsInstantiate(ctx, wasm); err != nil {
panic(err)
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.ctx = ctx
s.once.Do(s.compileModule)
if s.err != nil {
return nil, s.err
}
if Binary == nil && Path != "" {
if bin, err := os.ReadFile(Path); err != nil {
panic(err)
} else {
Binary = bin
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule() {
s.runtime = wazero.NewRuntime(s.ctx)
s.err = vfsInstantiate(s.ctx, s.runtime)
if s.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
return
}
}
if m, err := wasm.CompileModule(ctx, Binary); err != nil {
panic(err)
} else {
module = m
}
s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin)
}

37
conn.go
View File

@@ -2,18 +2,13 @@ package sqlite3
import (
"context"
"strconv"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
type Conn struct {
ctx context.Context
handle uint32
module api.Module
memory memory
api sqliteAPI
mem memory
handle uint32
}
func Open(filename string) (conn *Conn, err error) {
@@ -21,12 +16,8 @@ func Open(filename string) (conn *Conn, err error) {
}
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
once.Do(compile)
ctx := context.Background()
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10))
module, err := wasm.InstantiateModule(ctx, module, cfg)
module, err := sqlite3.instantiateModule(ctx)
if err != nil {
return nil, err
}
@@ -48,7 +39,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
return nil, err
}
c.handle = c.memory.readUint32(connPtr)
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
return nil, err
}
@@ -64,7 +55,7 @@ func (c *Conn) Close() error {
if err := c.error(r[0]); err != nil {
return err
}
return c.module.Close(c.ctx)
return c.mem.mod.Close(c.ctx)
}
func (c *Conn) Exec(sql string) error {
@@ -98,8 +89,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
}
stmt = &Stmt{c: c}
stmt.handle = c.memory.readUint32(stmtPtr)
i := c.memory.readUint32(tailPtr)
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0]); err != nil {
@@ -130,12 +121,12 @@ func (c *Conn) error(rc uint64) error {
// Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code.
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.getString(uint32(r[0]), 512)
err.msg = c.mem.readString(uint32(r[0]), 512)
}
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.getString(uint32(r[0]), 512)
err.str = c.mem.readString(uint32(r[0]), 512)
}
if err.msg == err.str {
@@ -161,7 +152,7 @@ func (c *Conn) new(len uint32) uint32 {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 || ptr >= c.memory.size() {
if ptr == 0 || ptr >= c.mem.size() {
panic(oomErr)
}
return ptr
@@ -174,7 +165,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
siz := uint32(len(b))
ptr := c.new(siz)
buf, ok := c.memory.read(ptr, siz)
buf, ok := c.mem.read(ptr, siz)
if !ok {
c.api.free.Call(c.ctx, uint64(ptr))
panic(rangeErr)
@@ -187,7 +178,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
func (c *Conn) newString(s string) uint32 {
siz := uint32(len(s) + 1)
ptr := c.new(siz)
buf, ok := c.memory.read(ptr, siz)
buf, ok := c.mem.read(ptr, siz)
if !ok {
c.api.free.Call(c.ctx, uint64(ptr))
panic(rangeErr)
@@ -197,7 +188,3 @@ func (c *Conn) newString(s string) uint32 {
copy(buf, s)
return ptr
}
func (c *Conn) getString(ptr, maxlen uint32) string {
return c.memory.readString(ptr, maxlen)
}

View File

@@ -37,7 +37,7 @@ func TestConn_newBytes(t *testing.T) {
}
want := buf
if got := db.memory.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) {
if got := db.mem.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -61,7 +61,7 @@ func TestConn_newString(t *testing.T) {
}
want := str + "\000"
if got := db.memory.mustRead(ptr, uint32(len(want))); string(got) != want {
if got := db.mem.mustRead(ptr, uint32(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -85,22 +85,22 @@ func TestConn_getString(t *testing.T) {
}
want := "sqlite3"
if got := db.getString(ptr, math.MaxUint32); got != want {
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := db.getString(ptr, 0); got != "" {
if got := db.mem.readString(ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
db.getString(ptr, uint32(len(want)/2))
db.mem.readString(ptr, uint32(len(want)/2))
t.Error("should have panicked")
}()
func() {
defer func() { _ = recover() }()
db.getString(0, math.MaxUint32)
db.mem.readString(0, math.MaxUint32)
t.Error("should have panicked")
}()
}

10
mem.go
View File

@@ -99,3 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string {
return string(buf[:i])
}
}
func (m memory) writeString(ptr uint32, s string) {
siz := uint32(len(s) + 1)
buf, ok := m.read(ptr, siz)
if !ok {
panic(rangeErr)
}
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -163,7 +163,7 @@ func (s *Stmt) ColumnText(col int) string {
panic(err)
}
mem := s.c.memory.mustRead(ptr, uint32(r[0]))
mem := s.c.mem.mustRead(ptr, uint32(r[0]))
return string(mem)
}
@@ -190,6 +190,6 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
panic(err)
}
mem := s.c.memory.mustRead(ptr, uint32(r[0]))
mem := s.c.mem.mustRead(ptr, uint32(r[0]))
return append(buf[0:0], mem...)
}

View File

@@ -7,12 +7,20 @@ import (
"io/fs"
"math/rand"
"os"
"path/filepath"
"testing"
"time"
"github.com/ncruces/julianday"
)
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1)
t.Error("should have panicked")
}
func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128)
@@ -96,43 +104,40 @@ func Test_vfsCurrentTime(t *testing.T) {
}
func Test_vfsCurrentTime64(t *testing.T) {
memory := make(mockMemory, 128)
module := &mockModule{&memory}
mem := newMemory(128)
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(context.TODO(), module, 0, 4)
rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000
if got, _ := memory.ReadUint64Le(4); int64(got)-want > 100 {
if got := mem.readUint64(4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsFullPathname(t *testing.T) {
memory := make(mockMemory, 128+_MAX_PATHNAME)
module := &mockModule{&memory}
mem := newMemory(128)
mem.writeString(4, ".")
memory.Write(4, []byte{'.', 0})
rc := vfsFullPathname(context.TODO(), module, 0, 4, 0, 8)
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(context.TODO(), module, 0, 4, _MAX_PATHNAME, 8)
rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
// want, _ := filepath.Abs(".")
// if got := getString(&memory, 8, _MAX_PATHNAME); got != want {
// t.Errorf("got %v, want %v", got, want)
// }
want, _ := filepath.Abs(".")
if got := mem.readString(8, _MAX_PATHNAME); got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsDelete(t *testing.T) {
@@ -156,7 +161,12 @@ func Test_vfsDelete(t *testing.T) {
}
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
t.Error("did not delete the file")
t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), module, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}