From 1e76a322bcd2151c14378b2d9da74259b895130f Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 12 Jan 2023 05:57:09 +0000 Subject: [PATCH] Refactor. --- .gitignore | 4 +- build.sh | 2 +- cmd/main.go | 16 ++++ embed/init.go | 14 ++++ go.mod | 2 +- sqlite3.go | 205 ++++++++++++++++++++++++++++---------------------- 6 files changed, 150 insertions(+), 93 deletions(-) create mode 100644 cmd/main.go create mode 100644 embed/init.go diff --git a/.gitignore b/.gitignore index 853b7f1..7ee877a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,6 @@ # Dependency directories (remove the comment below to include it) # vendor/ - tools -sqlite3.wasm \ No newline at end of file + +embed/sqlite3.wasm \ No newline at end of file diff --git a/build.sh b/build.sh index 43d6377..9768b9b 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/bin/sh -zig cc --target=wasm32-wasi -O2 -o sqlite3.wasm sqlite3/*.c \ +zig cc --target=wasm32-wasi -O2 -o embed/sqlite3.wasm sqlite3/*.c \ -DSQLITE_OS_OTHER=1 -DSQLITE_BYTEORDER=1234 \ -DHAVE_ISNAN -DHAVE_MALLOC_USABLE_SIZE \ -DSQLITE_DQS=0 \ diff --git a/cmd/main.go b/cmd/main.go new file mode 100644 index 0000000..be3cefb --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "log" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func main() { + db, err := sqlite3.Open(":memory:", sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE, "") + if err != nil { + log.Fatal(err) + } + defer db.Close() +} diff --git a/embed/init.go b/embed/init.go new file mode 100644 index 0000000..0dc5fc2 --- /dev/null +++ b/embed/init.go @@ -0,0 +1,14 @@ +package embed + +import ( + _ "embed" + + "github.com/ncruces/go-sqlite3" +) + +//go:embed sqlite3.wasm +var binary []byte + +func init() { + sqlite3.Binary = binary +} diff --git a/go.mod b/go.mod index 4cd1b29..9a7f7de 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module sqlite3 +module github.com/ncruces/go-sqlite3 go 1.19 diff --git a/sqlite3.go b/sqlite3.go index bb03347..c00aed9 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1,159 +1,172 @@ -package main +package sqlite3 import ( "bytes" "context" "errors" "fmt" - "log" - - _ "embed" + "os" + "strconv" + "sync" + "sync/atomic" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" ) -//go:embed sqlite3.wasm -var binary []byte +// Configure SQLite. +var ( + Binary []byte // Binary to load. + Path string // Path to load the binary from. +) -func main() { - var ctx = context.Background() +var ( + once sync.Once + wasm wazero.Runtime + module wazero.CompiledModule + counter atomic.Uint64 +) - wasm := wazero.NewRuntime(ctx) +func compile() { + ctx := context.Background() + + wasm = wazero.NewRuntime(ctx) wasi_snapshot_preview1.MustInstantiate(ctx, wasm) - compiled, err := wasm.CompileModule(ctx, binary) - if err != nil { + if Binary == nil && Path != "" { + if bin, err := os.ReadFile(Path); err != nil { + panic(err) + } else { + Binary = bin + } + } + + if m, err := wasm.CompileModule(ctx, Binary); err != nil { panic(err) + } else { + module = m } +} - cfg := wazero.NewModuleConfig() - module, err := wasm.InstantiateModule(ctx, compiled, cfg) +type Conn struct { + handle uint32 + module api.Module + memory api.Memory + api sqliteAPI +} + +func Open(name string, flags uint64, vfs string) (*Conn, error) { + once.Do(compile) + + ctx := context.TODO() + + cfg := wazero.NewModuleConfig(). + WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10)) + module, err := wasm.InstantiateModule(ctx, module, cfg) if err != nil { - panic(err) + return nil, err } - var db sqlite = sqlite{ - memory: module.Memory(), - _malloc: module.ExportedFunction("malloc"), - _free: module.ExportedFunction("free"), - _errmsg: module.ExportedFunction("sqlite3_errmsg"), - _open: module.ExportedFunction("sqlite3_open_v2"), - _close: module.ExportedFunction("sqlite3_close"), - _prepare: module.ExportedFunction("sqlite3_prepare_v2"), - _exec: module.ExportedFunction("sqlite3_exec"), - _step: module.ExportedFunction("sqlite3_step"), - _columnText: module.ExportedFunction("sqlite3_column_text"), - _columnInt: module.ExportedFunction("sqlite3_column_int64"), - _columnFloat: module.ExportedFunction("sqlite3_column_double"), + c := Conn{ + module: module, + memory: module.Memory(), + api: sqliteAPI{ + malloc: module.ExportedFunction("malloc"), + free: module.ExportedFunction("free"), + errmsg: module.ExportedFunction("sqlite3_errmsg"), + open: module.ExportedFunction("sqlite3_open_v2"), + close: module.ExportedFunction("sqlite3_close"), + prepare: module.ExportedFunction("sqlite3_prepare_v2"), + exec: module.ExportedFunction("sqlite3_exec"), + step: module.ExportedFunction("sqlite3_step"), + columnText: module.ExportedFunction("sqlite3_column_text"), + columnInt: module.ExportedFunction("sqlite3_column_int64"), + columnFloat: module.ExportedFunction("sqlite3_column_double"), + }, } - log.Println(err, db.memory.Size()) + namePtr := c.newString(name) + defer c.free(namePtr) - err = db.Open(":memory:", SQLITE_OPEN_READWRITE|SQLITE_OPEN_CREATE, "") - defer db.Close() - - log.Println(err, db.memory.Size()) -} - -type sqlite struct { - handle uint32 - memory api.Memory - _malloc api.Function - _free api.Function - _errmsg api.Function - _open api.Function - _close api.Function - _prepare api.Function - _exec api.Function - _step api.Function - _columnInt api.Function - _columnText api.Function - _columnFloat api.Function -} - -func (s *sqlite) Errmsg() error { - r, err := s._errmsg.Call(context.TODO(), uint64(s.handle)) - if err != nil { - return err - } - return errors.New(s.getString(r[0])) -} - -func (s *sqlite) Open(name string, flags uint64, vfs string) error { - namePtr := s.newString(name) - defer s.free(namePtr) - - handlePtr := s.newPtr() - defer s.free(handlePtr) + handlePtr := c.newPtr() + defer c.free(handlePtr) var vfsPtr uint32 if vfs != "" { - vfsPtr = s.newString(vfs) - defer s.free(vfsPtr) + vfsPtr = c.newString(vfs) + defer c.free(vfsPtr) } - r, err := s._open.Call(context.TODO(), uint64(namePtr), uint64(handlePtr), flags, uint64(vfsPtr)) + r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(handlePtr), flags, uint64(vfsPtr)) + if err != nil { + _ = c.Close() + return nil, err + } + + c.handle, _ = c.memory.ReadUint32Le(handlePtr) + + if r[0] != SQLITE_OK { + err := fmt.Errorf("sqlite error (%d): %s", r[0], c.Errmsg()) + _ = c.Close() + return nil, err + } + return &c, nil +} + +func (c *Conn) Errmsg() error { + r, err := c.api.errmsg.Call(context.TODO(), uint64(c.handle)) if err != nil { return err } - - s.handle, _ = s.memory.ReadUint32Le(handlePtr) - - if r[0] != SQLITE_OK { - err := fmt.Errorf("sqlite error (%d): %s", r[0], s.Errmsg()) - _ = s.Close() - return err - } - return nil + return errors.New(c.getString(r[0])) } -func (s *sqlite) Close() error { - r, err := s._close.Call(context.TODO(), uint64(s.handle)) +func (c *Conn) Close() error { + r, err := c.api.close.Call(context.TODO(), uint64(c.handle)) if err != nil { return err } if r[0] != SQLITE_OK { - return fmt.Errorf("sqlite error (%d): %s", r[0], s.Errmsg()) + return fmt.Errorf("sqlite error (%d): %s", r[0], c.Errmsg()) } return nil } -func (s *sqlite) free(ptr uint32) { - _, err := s._free.Call(context.TODO(), uint64(ptr)) +func (c *Conn) free(ptr uint32) { + _, err := c.api.free.Call(context.TODO(), uint64(ptr)) if err != nil { panic(err) } } -func (s *sqlite) newPtr() uint32 { - r, err := s._malloc.Call(context.TODO(), 4) +func (c *Conn) newPtr() uint32 { + r, err := c.api.malloc.Call(context.TODO(), 4) if err != nil { panic(err) } return uint32(r[0]) } -func (s *sqlite) newString(str string) uint32 { - r, err := s._malloc.Call(context.TODO(), uint64(len(str)+1)) +func (c *Conn) newString(str string) uint32 { + r, err := c.api.malloc.Call(context.TODO(), uint64(len(str)+1)) if err != nil { panic(err) } ptr := uint32(r[0]) - if ok := s.memory.Write(ptr, []byte(str)); !ok { + if ok := c.memory.Write(ptr, []byte(str)); !ok { panic("failed init string") } - if ok := s.memory.WriteByte(ptr+uint32(len(str)), 0); !ok { + if ok := c.memory.WriteByte(ptr+uint32(len(str)), 0); !ok { panic("failed init string") } return ptr } -func (s *sqlite) getString(ptr uint64) string { - buf, ok := s.memory.Read(uint32(ptr), 64) +func (c *Conn) getString(ptr uint64) string { + buf, ok := c.memory.Read(uint32(ptr), 64) if !ok { panic("failed read string") } @@ -172,3 +185,17 @@ const ( SQLITE_OPEN_READWRITE = 0x00000002 SQLITE_OPEN_CREATE = 0x00000004 ) + +type sqliteAPI struct { + malloc api.Function + free api.Function + errmsg api.Function + open api.Function + close api.Function + prepare api.Function + exec api.Function + step api.Function + columnInt api.Function + columnText api.Function + columnFloat api.Function +}