mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Refactor.
This commit is contained in:
3
api.go
3
api.go
@@ -18,8 +18,7 @@ func newConn(module api.Module) *Conn {
|
|||||||
destructor := memory{module}.readUint32(uint32(global.Get()))
|
destructor := memory{module}.readUint32(uint32(global.Get()))
|
||||||
|
|
||||||
return &Conn{
|
return &Conn{
|
||||||
module: module,
|
mem: memory{module},
|
||||||
memory: memory{module},
|
|
||||||
api: sqliteAPI{
|
api: sqliteAPI{
|
||||||
malloc: getFun("malloc"),
|
malloc: getFun("malloc"),
|
||||||
free: getFun("free"),
|
free: getFun("free"),
|
||||||
|
|||||||
56
compile.go
56
compile.go
@@ -3,10 +3,12 @@ package sqlite3
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/tetratelabs/wazero"
|
"github.com/tetratelabs/wazero"
|
||||||
|
"github.com/tetratelabs/wazero/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Configure SQLite.
|
// Configure SQLite.
|
||||||
@@ -15,33 +17,43 @@ var (
|
|||||||
Path string // Path to load the binary from.
|
Path string // Path to load the binary from.
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var sqlite3 sqlite3Runtime
|
||||||
once sync.Once
|
|
||||||
wasm wazero.Runtime
|
|
||||||
module wazero.CompiledModule
|
|
||||||
counter atomic.Uint64
|
|
||||||
)
|
|
||||||
|
|
||||||
func compile() {
|
type sqlite3Runtime struct {
|
||||||
ctx := context.Background()
|
once sync.Once
|
||||||
|
runtime wazero.Runtime
|
||||||
|
compiled wazero.CompiledModule
|
||||||
|
instances atomic.Uint64
|
||||||
|
ctx context.Context
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
wasm = wazero.NewRuntime(ctx)
|
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
|
||||||
|
s.ctx = ctx
|
||||||
if err := vfsInstantiate(ctx, wasm); err != nil {
|
s.once.Do(s.compileModule)
|
||||||
panic(err)
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
if Binary == nil && Path != "" {
|
cfg := wazero.NewModuleConfig().
|
||||||
if bin, err := os.ReadFile(Path); err != nil {
|
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
|
||||||
panic(err)
|
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
|
||||||
} else {
|
}
|
||||||
Binary = bin
|
|
||||||
|
func (s *sqlite3Runtime) compileModule() {
|
||||||
|
s.runtime = wazero.NewRuntime(s.ctx)
|
||||||
|
s.err = vfsInstantiate(s.ctx, s.runtime)
|
||||||
|
if s.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bin := Binary
|
||||||
|
if bin == nil && Path != "" {
|
||||||
|
bin, s.err = os.ReadFile(Path)
|
||||||
|
if s.err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if m, err := wasm.CompileModule(ctx, Binary); err != nil {
|
s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin)
|
||||||
panic(err)
|
|
||||||
} else {
|
|
||||||
module = m
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
37
conn.go
37
conn.go
@@ -2,18 +2,13 @@ package sqlite3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/tetratelabs/wazero"
|
|
||||||
"github.com/tetratelabs/wazero/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
handle uint32
|
|
||||||
module api.Module
|
|
||||||
memory memory
|
|
||||||
api sqliteAPI
|
api sqliteAPI
|
||||||
|
mem memory
|
||||||
|
handle uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func Open(filename string) (conn *Conn, err error) {
|
func Open(filename string) (conn *Conn, err error) {
|
||||||
@@ -21,12 +16,8 @@ func Open(filename string) (conn *Conn, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||||
once.Do(compile)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
cfg := wazero.NewModuleConfig().
|
module, err := sqlite3.instantiateModule(ctx)
|
||||||
WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10))
|
|
||||||
module, err := wasm.InstantiateModule(ctx, module, cfg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -48,7 +39,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.handle = c.memory.readUint32(connPtr)
|
c.handle = c.mem.readUint32(connPtr)
|
||||||
if err := c.error(r[0]); err != nil {
|
if err := c.error(r[0]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -64,7 +55,7 @@ func (c *Conn) Close() error {
|
|||||||
if err := c.error(r[0]); err != nil {
|
if err := c.error(r[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.module.Close(c.ctx)
|
return c.mem.mod.Close(c.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Exec(sql string) error {
|
func (c *Conn) Exec(sql string) error {
|
||||||
@@ -98,8 +89,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
|||||||
}
|
}
|
||||||
|
|
||||||
stmt = &Stmt{c: c}
|
stmt = &Stmt{c: c}
|
||||||
stmt.handle = c.memory.readUint32(stmtPtr)
|
stmt.handle = c.mem.readUint32(stmtPtr)
|
||||||
i := c.memory.readUint32(tailPtr)
|
i := c.mem.readUint32(tailPtr)
|
||||||
tail = sql[i-sqlPtr:]
|
tail = sql[i-sqlPtr:]
|
||||||
|
|
||||||
if err := c.error(r[0]); err != nil {
|
if err := c.error(r[0]); err != nil {
|
||||||
@@ -130,12 +121,12 @@ func (c *Conn) error(rc uint64) error {
|
|||||||
// Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code.
|
// Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code.
|
||||||
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
|
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
|
||||||
if r != nil {
|
if r != nil {
|
||||||
err.msg = c.getString(uint32(r[0]), 512)
|
err.msg = c.mem.readString(uint32(r[0]), 512)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, _ = c.api.errstr.Call(c.ctx, rc)
|
r, _ = c.api.errstr.Call(c.ctx, rc)
|
||||||
if r != nil {
|
if r != nil {
|
||||||
err.str = c.getString(uint32(r[0]), 512)
|
err.str = c.mem.readString(uint32(r[0]), 512)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err.msg == err.str {
|
if err.msg == err.str {
|
||||||
@@ -161,7 +152,7 @@ func (c *Conn) new(len uint32) uint32 {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
ptr := uint32(r[0])
|
ptr := uint32(r[0])
|
||||||
if ptr == 0 || ptr >= c.memory.size() {
|
if ptr == 0 || ptr >= c.mem.size() {
|
||||||
panic(oomErr)
|
panic(oomErr)
|
||||||
}
|
}
|
||||||
return ptr
|
return ptr
|
||||||
@@ -174,7 +165,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
|
|||||||
|
|
||||||
siz := uint32(len(b))
|
siz := uint32(len(b))
|
||||||
ptr := c.new(siz)
|
ptr := c.new(siz)
|
||||||
buf, ok := c.memory.read(ptr, siz)
|
buf, ok := c.mem.read(ptr, siz)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.api.free.Call(c.ctx, uint64(ptr))
|
c.api.free.Call(c.ctx, uint64(ptr))
|
||||||
panic(rangeErr)
|
panic(rangeErr)
|
||||||
@@ -187,7 +178,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
|
|||||||
func (c *Conn) newString(s string) uint32 {
|
func (c *Conn) newString(s string) uint32 {
|
||||||
siz := uint32(len(s) + 1)
|
siz := uint32(len(s) + 1)
|
||||||
ptr := c.new(siz)
|
ptr := c.new(siz)
|
||||||
buf, ok := c.memory.read(ptr, siz)
|
buf, ok := c.mem.read(ptr, siz)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.api.free.Call(c.ctx, uint64(ptr))
|
c.api.free.Call(c.ctx, uint64(ptr))
|
||||||
panic(rangeErr)
|
panic(rangeErr)
|
||||||
@@ -197,7 +188,3 @@ func (c *Conn) newString(s string) uint32 {
|
|||||||
copy(buf, s)
|
copy(buf, s)
|
||||||
return ptr
|
return ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) getString(ptr, maxlen uint32) string {
|
|
||||||
return c.memory.readString(ptr, maxlen)
|
|
||||||
}
|
|
||||||
|
|||||||
12
conn_test.go
12
conn_test.go
@@ -37,7 +37,7 @@ func TestConn_newBytes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
want := buf
|
want := buf
|
||||||
if got := db.memory.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) {
|
if got := db.mem.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) {
|
||||||
t.Errorf("got %q, want %q", got, want)
|
t.Errorf("got %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -61,7 +61,7 @@ func TestConn_newString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
want := str + "\000"
|
want := str + "\000"
|
||||||
if got := db.memory.mustRead(ptr, uint32(len(want))); string(got) != want {
|
if got := db.mem.mustRead(ptr, uint32(len(want))); string(got) != want {
|
||||||
t.Errorf("got %q, want %q", got, want)
|
t.Errorf("got %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -85,22 +85,22 @@ func TestConn_getString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
want := "sqlite3"
|
want := "sqlite3"
|
||||||
if got := db.getString(ptr, math.MaxUint32); got != want {
|
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
|
||||||
t.Errorf("got %q, want %q", got, want)
|
t.Errorf("got %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
if got := db.getString(ptr, 0); got != "" {
|
if got := db.mem.readString(ptr, 0); got != "" {
|
||||||
t.Errorf("got %q, want empty", got)
|
t.Errorf("got %q, want empty", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
defer func() { _ = recover() }()
|
defer func() { _ = recover() }()
|
||||||
db.getString(ptr, uint32(len(want)/2))
|
db.mem.readString(ptr, uint32(len(want)/2))
|
||||||
t.Error("should have panicked")
|
t.Error("should have panicked")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
defer func() { _ = recover() }()
|
defer func() { _ = recover() }()
|
||||||
db.getString(0, math.MaxUint32)
|
db.mem.readString(0, math.MaxUint32)
|
||||||
t.Error("should have panicked")
|
t.Error("should have panicked")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|||||||
10
mem.go
10
mem.go
@@ -99,3 +99,13 @@ func (m memory) readString(ptr, maxlen uint32) string {
|
|||||||
return string(buf[:i])
|
return string(buf[:i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m memory) writeString(ptr uint32, s string) {
|
||||||
|
siz := uint32(len(s) + 1)
|
||||||
|
buf, ok := m.read(ptr, siz)
|
||||||
|
if !ok {
|
||||||
|
panic(rangeErr)
|
||||||
|
}
|
||||||
|
buf[len(s)] = 0
|
||||||
|
copy(buf, s)
|
||||||
|
}
|
||||||
|
|||||||
4
stmt.go
4
stmt.go
@@ -163,7 +163,7 @@ func (s *Stmt) ColumnText(col int) string {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mem := s.c.memory.mustRead(ptr, uint32(r[0]))
|
mem := s.c.mem.mustRead(ptr, uint32(r[0]))
|
||||||
return string(mem)
|
return string(mem)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,6 +190,6 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mem := s.c.memory.mustRead(ptr, uint32(r[0]))
|
mem := s.c.mem.mustRead(ptr, uint32(r[0]))
|
||||||
return append(buf[0:0], mem...)
|
return append(buf[0:0], mem...)
|
||||||
}
|
}
|
||||||
|
|||||||
40
vfs_test.go
40
vfs_test.go
@@ -7,12 +7,20 @@ import (
|
|||||||
"io/fs"
|
"io/fs"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ncruces/julianday"
|
"github.com/ncruces/julianday"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func Test_vfsExit(t *testing.T) {
|
||||||
|
mem := newMemory(128)
|
||||||
|
defer func() { _ = recover() }()
|
||||||
|
vfsExit(context.TODO(), mem.mod, 1)
|
||||||
|
t.Error("should have panicked")
|
||||||
|
}
|
||||||
|
|
||||||
func Test_vfsLocaltime(t *testing.T) {
|
func Test_vfsLocaltime(t *testing.T) {
|
||||||
mem := newMemory(128)
|
mem := newMemory(128)
|
||||||
|
|
||||||
@@ -96,43 +104,40 @@ func Test_vfsCurrentTime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_vfsCurrentTime64(t *testing.T) {
|
func Test_vfsCurrentTime64(t *testing.T) {
|
||||||
memory := make(mockMemory, 128)
|
mem := newMemory(128)
|
||||||
module := &mockModule{&memory}
|
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
time.Sleep(time.Millisecond)
|
time.Sleep(time.Millisecond)
|
||||||
rc := vfsCurrentTime64(context.TODO(), module, 0, 4)
|
rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
|
||||||
if rc != 0 {
|
if rc != 0 {
|
||||||
t.Fatal("returned", rc)
|
t.Fatal("returned", rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
day, nsec := julianday.Date(now)
|
day, nsec := julianday.Date(now)
|
||||||
want := day*86_400_000 + nsec/1_000_000
|
want := day*86_400_000 + nsec/1_000_000
|
||||||
if got, _ := memory.ReadUint64Le(4); int64(got)-want > 100 {
|
if got := mem.readUint64(4); float32(got) != float32(want) {
|
||||||
t.Errorf("got %v, want %v", got, want)
|
t.Errorf("got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_vfsFullPathname(t *testing.T) {
|
func Test_vfsFullPathname(t *testing.T) {
|
||||||
memory := make(mockMemory, 128+_MAX_PATHNAME)
|
mem := newMemory(128)
|
||||||
module := &mockModule{&memory}
|
mem.writeString(4, ".")
|
||||||
|
|
||||||
memory.Write(4, []byte{'.', 0})
|
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
|
||||||
|
|
||||||
rc := vfsFullPathname(context.TODO(), module, 0, 4, 0, 8)
|
|
||||||
if rc != uint32(CANTOPEN_FULLPATH) {
|
if rc != uint32(CANTOPEN_FULLPATH) {
|
||||||
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
|
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
|
||||||
}
|
}
|
||||||
|
|
||||||
rc = vfsFullPathname(context.TODO(), module, 0, 4, _MAX_PATHNAME, 8)
|
rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
|
||||||
if rc != _OK {
|
if rc != _OK {
|
||||||
t.Fatal("returned", rc)
|
t.Fatal("returned", rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// want, _ := filepath.Abs(".")
|
want, _ := filepath.Abs(".")
|
||||||
// if got := getString(&memory, 8, _MAX_PATHNAME); got != want {
|
if got := mem.readString(8, _MAX_PATHNAME); got != want {
|
||||||
// t.Errorf("got %v, want %v", got, want)
|
t.Errorf("got %v, want %v", got, want)
|
||||||
// }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_vfsDelete(t *testing.T) {
|
func Test_vfsDelete(t *testing.T) {
|
||||||
@@ -156,7 +161,12 @@ func Test_vfsDelete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
|
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
|
||||||
t.Error("did not delete the file")
|
t.Fatal("did not delete the file")
|
||||||
|
}
|
||||||
|
|
||||||
|
rc = vfsDelete(context.TODO(), module, 0, 4, 1)
|
||||||
|
if rc != _OK {
|
||||||
|
t.Fatal("returned", rc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user