Towards shared modules: refactor.

This commit is contained in:
Nuno Cruces
2023-03-06 23:41:54 +00:00
parent d291738b81
commit 6c96a019e6
3 changed files with 49 additions and 49 deletions

View File

@@ -42,14 +42,13 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
mod, err := instantiateModule()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
mod.Close(ctx)
mod.close()
} else {
runtime.SetFinalizer(conn, finalizer[Conn](3))
}
@@ -128,7 +127,7 @@ func (c *Conn) Close() error {
c.handle = 0
runtime.SetFinalizer(c, nil)
return c.mem.mod.Close(c.ctx)
return c.module.close()
}
// Exec is a convenience function that allows an application to run

View File

@@ -77,16 +77,18 @@ func compileModule() {
}
type module struct {
api.Module
ctx context.Context
mem memory
api sqliteAPI
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mem = memory{mod}
m.ctx = context.Background()
getFun := func(name string) api.Function {
f := m.ExportedFunction(name)
f := mod.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
@@ -95,7 +97,7 @@ func newModule(mod api.Module) (m *module, err error) {
}
getVal := func(name string) uint32 {
global := m.ExportedGlobal(name)
global := mod.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
@@ -103,11 +105,6 @@ func newModule(mod api.Module) (m *module, err error) {
return m.mem.readUint32(uint32(global.Get()))
}
m = &module{
Module: mod,
mem: memory{mod},
ctx: context.Background(),
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
@@ -164,6 +161,10 @@ func newModule(mod api.Module) (m *module, err error) {
return
}
func (m *module) close() error {
return m.mem.mod.Close(m.ctx)
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil

View File

@@ -9,43 +9,43 @@ import (
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.error(uint64(NOMEM))
m.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.call(db.api.free)
m.call(m.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
testOOM := func(size uint64) {
defer func() { _ = recover() }()
db.new(size)
m.new(size)
t.Error("want panic")
}
@@ -56,13 +56,13 @@ func TestConn_new(t *testing.T) {
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
arena := db.newArena(16)
arena := m.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -71,7 +71,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
if got := m.mem.readString(ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -80,7 +80,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
if got := m.mem.readString(ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
@@ -89,25 +89,25 @@ func TestConn_newArena(t *testing.T) {
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newBytes(nil)
ptr := m.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = db.newBytes(buf)
ptr = m.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := db.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -115,25 +115,25 @@ func TestConn_newBytes(t *testing.T) {
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := db.mem.view(ptr, uint64(len(want))); string(got) != want {
if got := m.mem.view(ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -141,40 +141,40 @@ func TestConn_newString(t *testing.T) {
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
if got := m.mem.readString(ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := db.mem.readString(ptr, 0); got != "" {
if got := m.mem.readString(ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
db.mem.readString(ptr, uint32(len(want)/2))
m.mem.readString(ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
db.mem.readString(0, math.MaxUint32)
m.mem.readString(0, math.MaxUint32)
t.Error("want panic")
}()
}
@@ -182,18 +182,18 @@ func TestConn_getString(t *testing.T) {
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
db.free(0)
m.free(0)
ptr := db.new(1)
ptr := m.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
db.free(ptr)
m.free(ptr)
}