From 6c96a019e67d4dd43ec8ce0ffb02dcab7f8f7464 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 6 Mar 2023 23:41:54 +0000 Subject: [PATCH] Towards shared modules: refactor. --- conn.go | 5 +-- module.go | 19 ++++----- conn_test.go => module_test.go | 74 +++++++++++++++++----------------- 3 files changed, 49 insertions(+), 49 deletions(-) rename conn_test.go => module_test.go (67%) diff --git a/conn.go b/conn.go index e0a7861..5b22aff 100644 --- a/conn.go +++ b/conn.go @@ -42,14 +42,13 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { } func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { - ctx := context.Background() mod, err := instantiateModule() if err != nil { return nil, err } defer func() { if conn == nil { - mod.Close(ctx) + mod.close() } else { runtime.SetFinalizer(conn, finalizer[Conn](3)) } @@ -128,7 +127,7 @@ func (c *Conn) Close() error { c.handle = 0 runtime.SetFinalizer(c, nil) - return c.mem.mod.Close(c.ctx) + return c.module.close() } // Exec is a convenience function that allows an application to run diff --git a/module.go b/module.go index 29d6d4e..9d8f845 100644 --- a/module.go +++ b/module.go @@ -77,16 +77,18 @@ func compileModule() { } type module struct { - api.Module - ctx context.Context mem memory api sqliteAPI } func newModule(mod api.Module) (m *module, err error) { + m = &module{} + m.mem = memory{mod} + m.ctx = context.Background() + getFun := func(name string) api.Function { - f := m.ExportedFunction(name) + f := mod.ExportedFunction(name) if f == nil { err = noFuncErr + errorString(name) return nil @@ -95,7 +97,7 @@ func newModule(mod api.Module) (m *module, err error) { } getVal := func(name string) uint32 { - global := m.ExportedGlobal(name) + global := mod.ExportedGlobal(name) if global == nil { err = noGlobalErr + errorString(name) return 0 @@ -103,11 +105,6 @@ func newModule(mod api.Module) (m *module, err error) { return m.mem.readUint32(uint32(global.Get())) } - m = &module{ - Module: mod, - mem: memory{mod}, - ctx: context.Background(), - } m.api = sqliteAPI{ free: getFun("free"), malloc: getFun("malloc"), @@ -164,6 +161,10 @@ func newModule(mod api.Module) (m *module, err error) { return } +func (m *module) close() error { + return m.mem.mod.Close(m.ctx) +} + func (m *module) error(rc uint64, handle uint32, sql ...string) error { if rc == _OK { return nil diff --git a/conn_test.go b/module_test.go similarity index 67% rename from conn_test.go rename to module_test.go index 7aff16d..a0bcebd 100644 --- a/conn_test.go +++ b/module_test.go @@ -9,43 +9,43 @@ import ( func TestConn_error_OOM(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() defer func() { _ = recover() }() - db.error(uint64(NOMEM)) + m.error(uint64(NOMEM), 0) t.Error("want panic") } func TestConn_call_nil(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() defer func() { _ = recover() }() - db.call(db.api.free) + m.call(m.api.free) t.Error("want panic") } func TestConn_new(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() testOOM := func(size uint64) { defer func() { _ = recover() }() - db.new(size) + m.new(size) t.Error("want panic") } @@ -56,13 +56,13 @@ func TestConn_new(t *testing.T) { func TestConn_newArena(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() - arena := db.newArena(16) + arena := m.newArena(16) defer arena.free() const title = "Lorem ipsum" @@ -71,7 +71,7 @@ func TestConn_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := db.mem.readString(ptr, math.MaxUint32); got != title { + if got := m.mem.readString(ptr, math.MaxUint32); got != title { t.Errorf("got %q, want %q", got, title) } @@ -80,7 +80,7 @@ func TestConn_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := db.mem.readString(ptr, math.MaxUint32); got != body { + if got := m.mem.readString(ptr, math.MaxUint32); got != body { t.Errorf("got %q, want %q", got, body) } arena.free() @@ -89,25 +89,25 @@ func TestConn_newArena(t *testing.T) { func TestConn_newBytes(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() - ptr := db.newBytes(nil) + ptr := m.newBytes(nil) if ptr != 0 { t.Errorf("got %#x, want nullptr", ptr) } buf := []byte("sqlite3") - ptr = db.newBytes(buf) + ptr = m.newBytes(buf) if ptr == 0 { t.Fatal("got nullptr, want a pointer") } want := buf - if got := db.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) { + if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } } @@ -115,25 +115,25 @@ func TestConn_newBytes(t *testing.T) { func TestConn_newString(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() - ptr := db.newString("") + ptr := m.newString("") if ptr == 0 { t.Error("got nullptr, want a pointer") } str := "sqlite3\000sqlite3" - ptr = db.newString(str) + ptr = m.newString(str) if ptr == 0 { t.Fatal("got nullptr, want a pointer") } want := str + "\000" - if got := db.mem.view(ptr, uint64(len(want))); string(got) != want { + if got := m.mem.view(ptr, uint64(len(want))); string(got) != want { t.Errorf("got %q, want %q", got, want) } } @@ -141,40 +141,40 @@ func TestConn_newString(t *testing.T) { func TestConn_getString(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() - ptr := db.newString("") + ptr := m.newString("") if ptr == 0 { t.Error("got nullptr, want a pointer") } str := "sqlite3" + "\000 drop this" - ptr = db.newString(str) + ptr = m.newString(str) if ptr == 0 { t.Fatal("got nullptr, want a pointer") } want := "sqlite3" - if got := db.mem.readString(ptr, math.MaxUint32); got != want { + if got := m.mem.readString(ptr, math.MaxUint32); got != want { t.Errorf("got %q, want %q", got, want) } - if got := db.mem.readString(ptr, 0); got != "" { + if got := m.mem.readString(ptr, 0); got != "" { t.Errorf("got %q, want empty", got) } func() { defer func() { _ = recover() }() - db.mem.readString(ptr, uint32(len(want)/2)) + m.mem.readString(ptr, uint32(len(want)/2)) t.Error("want panic") }() func() { defer func() { _ = recover() }() - db.mem.readString(0, math.MaxUint32) + m.mem.readString(0, math.MaxUint32) t.Error("want panic") }() } @@ -182,18 +182,18 @@ func TestConn_getString(t *testing.T) { func TestConn_free(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + m, err := instantiateModule() if err != nil { t.Fatal(err) } - defer db.Close() + defer m.close() - db.free(0) + m.free(0) - ptr := db.new(1) + ptr := m.new(1) if ptr == 0 { t.Error("got nullptr, want a pointer") } - db.free(ptr) + m.free(ptr) }