diff --git a/build_deps.sh b/build_deps.sh index 17772c6..3d13d46 100755 --- a/build_deps.sh +++ b/build_deps.sh @@ -40,6 +40,9 @@ zig cc --target=wasm32-wasi -flto -g0 -O2 \ -Wl,--export=sqlite3_column_text \ -Wl,--export=sqlite3_column_int64 \ -Wl,--export=sqlite3_column_double \ + -Wl,--export=sqlite3_errstr \ -Wl,--export=sqlite3_errmsg \ + -Wl,--export=sqlite3_error_offset \ + -Wl,--export=sqlite3_extended_errcode \ -Wl,--export=malloc \ -Wl,--export=free diff --git a/cmd/main.go b/cmd/main.go index be3cefb..f419601 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,9 +8,23 @@ import ( ) func main() { - db, err := sqlite3.Open(":memory:", sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE, "") + db, err := sqlite3.Open(":memory:", 0, "") + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id int, name varchar(10))`) + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`) + if err != nil { + log.Fatal(err) + } + + err = db.Close() if err != nil { log.Fatal(err) } - defer db.Close() } diff --git a/compile.go b/compile.go new file mode 100644 index 0000000..5f6aa1e --- /dev/null +++ b/compile.go @@ -0,0 +1,45 @@ +package sqlite3 + +import ( + "context" + "os" + "sync" + "sync/atomic" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" +) + +// Configure SQLite. +var ( + Binary []byte // Binary to load. + Path string // Path to load the binary from. +) + +var ( + once sync.Once + wasm wazero.Runtime + module wazero.CompiledModule + counter atomic.Uint64 +) + +func compile() { + ctx := context.Background() + + wasm = wazero.NewRuntime(ctx) + wasi_snapshot_preview1.MustInstantiate(ctx, wasm) + + if Binary == nil && Path != "" { + if bin, err := os.ReadFile(Path); err != nil { + panic(err) + } else { + Binary = bin + } + } + + if m, err := wasm.CompileModule(ctx, Binary); err != nil { + panic(err) + } else { + module = m + } +} diff --git a/conn.go b/conn.go index 0242b01..b703cdd 100644 --- a/conn.go +++ b/conn.go @@ -3,52 +3,15 @@ package sqlite3 import ( "bytes" "context" - "errors" - "fmt" - "os" + "io/fs" + "path/filepath" "strconv" - "sync" - "sync/atomic" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" - "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "github.com/tetratelabs/wazero/experimental/writefs" ) -// Configure SQLite. -var ( - Binary []byte // Binary to load. - Path string // Path to load the binary from. -) - -var ( - once sync.Once - wasm wazero.Runtime - module wazero.CompiledModule - counter atomic.Uint64 -) - -func compile() { - ctx := context.Background() - - wasm = wazero.NewRuntime(ctx) - wasi_snapshot_preview1.MustInstantiate(ctx, wasm) - - if Binary == nil && Path != "" { - if bin, err := os.ReadFile(Path); err != nil { - panic(err) - } else { - Binary = bin - } - } - - if m, err := wasm.CompileModule(ctx, Binary); err != nil { - panic(err) - } else { - module = m - } -} - type Conn struct { handle uint32 module api.Module @@ -56,13 +19,26 @@ type Conn struct { api sqliteAPI } -func Open(name string, flags uint64, vfs string) (*Conn, error) { +func Open(name string, flags uint64, vfs string) (conn *Conn, err error) { once.Do(compile) + var fs fs.FS + if name != ":memory:" { + dir := filepath.Dir(name) + name = filepath.Base(name) + fs, err = writefs.NewDirFS(dir) + if err != nil { + return nil, err + } + } + ctx := context.TODO() cfg := wazero.NewModuleConfig(). WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10)) + if fs != nil { + cfg = cfg.WithFS(fs) + } module, err := wasm.InstantiateModule(ctx, module, cfg) if err != nil { return nil, err @@ -74,7 +50,10 @@ func Open(name string, flags uint64, vfs string) (*Conn, error) { api: sqliteAPI{ malloc: module.ExportedFunction("malloc"), free: module.ExportedFunction("free"), + errstr: module.ExportedFunction("sqlite3_errstr"), errmsg: module.ExportedFunction("sqlite3_errmsg"), + erroff: module.ExportedFunction("sqlite3_error_offset"), + errext: module.ExportedFunction("sqlite3_extended_errcode"), open: module.ExportedFunction("sqlite3_open_v2"), close: module.ExportedFunction("sqlite3_close"), prepare: module.ExportedFunction("sqlite3_prepare_v2"), @@ -86,55 +65,102 @@ func Open(name string, flags uint64, vfs string) (*Conn, error) { }, } - namePtr := c.newString(name) - defer c.free(namePtr) + defer func() { + if conn == nil { + c.Close() + } + }() + namePtr := c.newString(name) handlePtr := c.newBytes(4) - defer c.free(handlePtr) + + if flags == 0 { + flags = OPEN_READWRITE | OPEN_CREATE + } var vfsPtr uint32 if vfs != "" { vfsPtr = c.newString(vfs) - defer c.free(vfsPtr) } r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(handlePtr), flags, uint64(vfsPtr)) if err != nil { - _ = c.Close() return nil, err } c.handle, _ = c.memory.ReadUint32Le(handlePtr) + c.free(handlePtr) + c.free(namePtr) + c.free(vfsPtr) - if r[0] != SQLITE_OK { - err := fmt.Errorf("sqlite error (%d): %s", r[0], c.Errmsg()) - _ = c.Close() - return nil, err + if r[0] != OK { + return nil, c.error(r[0]) } return &c, nil } -func (c *Conn) Errmsg() error { - r, err := c.api.errmsg.Call(context.TODO(), uint64(c.handle)) - if err != nil { - return err - } - return errors.New(c.getString(uint32(r[0]), 64)) -} - func (c *Conn) Close() error { r, err := c.api.close.Call(context.TODO(), uint64(c.handle)) if err != nil { return err } - if r[0] != SQLITE_OK { - return fmt.Errorf("sqlite error (%d): %s", r[0], c.Errmsg()) + if r[0] != OK { + return c.error(r[0]) + } + return c.module.Close(context.TODO()) +} + +func (c *Conn) Exec(sql string) error { + sqlPtr := c.newString(sql) + + r, err := c.api.exec.Call(context.TODO(), uint64(c.handle), uint64(sqlPtr), 0, 0, 0) + if err != nil { + return err + } + + c.free(sqlPtr) + + if r[0] != OK { + return c.error(r[0]) } return nil } +func (c *Conn) error(rc uint64) *Error { + serr := Error{Code: int(rc)} + + var r []uint64 + + // string + r, _ = c.api.errstr.Call(context.TODO(), uint64(rc)) + if r != nil { + serr.str = c.getString(uint32(r[0]), 512) + } + + // message + r, _ = c.api.errmsg.Call(context.TODO(), uint64(c.handle)) + if r != nil { + serr.msg = c.getString(uint32(r[0]), 512) + } + + // extended code + r, _ = c.api.errext.Call(context.TODO(), uint64(c.handle)) + if r != nil { + serr.ExtendedCode = int(r[0]) + } + + if serr.str == serr.msg { + serr.msg = "" + } + + return &serr +} + func (c *Conn) free(ptr uint32) { + if ptr == 0 { + return + } _, err := c.api.free.Call(context.TODO(), uint64(ptr)) if err != nil { panic(err) @@ -183,19 +209,13 @@ func (c *Conn) getString(ptr, maxlen uint32) string { } } -const ( - SQLITE_OK = 0 - SQLITE_ROW = 100 - SQLITE_DONE = 101 - - SQLITE_OPEN_READWRITE = 0x00000002 - SQLITE_OPEN_CREATE = 0x00000004 -) - type sqliteAPI struct { malloc api.Function free api.Function + errstr api.Function errmsg api.Function + errext api.Function + erroff api.Function open api.Function close api.Function prepare api.Function diff --git a/const.go b/const.go new file mode 100644 index 0000000..f064bc6 --- /dev/null +++ b/const.go @@ -0,0 +1,12 @@ +package sqlite3 + +const ( + OK = 0 + ROW = 100 + DONE = 101 +) + +const ( + OPEN_READWRITE = 0x00000002 + OPEN_CREATE = 0x00000004 +) diff --git a/error.go b/error.go new file mode 100644 index 0000000..e492eb9 --- /dev/null +++ b/error.go @@ -0,0 +1,32 @@ +package sqlite3 + +import ( + "strconv" + "strings" +) + +type Error struct { + Code int + ExtendedCode int + str string + msg string +} + +func (e Error) Error() string { + var b strings.Builder + b.WriteString("sqlite3: ") + + if e.str != "" { + b.WriteString(e.str) + } else { + b.WriteString(strconv.Itoa(e.Code)) + } + + if e.msg != "" { + b.WriteByte(':') + b.WriteByte(' ') + b.WriteString(e.msg) + } + + return b.String() +}