diff --git a/api.go b/api.go new file mode 100644 index 0000000..fd97206 --- /dev/null +++ b/api.go @@ -0,0 +1,81 @@ +package sqlite3 + +import "github.com/tetratelabs/wazero/api" + +func newConn(module api.Module) *Conn { + getFun := func(name string) api.Function { + f := module.ExportedFunction(name) + if f == nil { + panic("sqlite3: could not find " + name + " function") + } + return f + } + + global := module.ExportedGlobal("malloc_destructor") + if global == nil { + panic("sqlite3: could not find malloc_destructor global") + } + destructor := uint32(global.Get()) + destructor, ok := module.Memory().ReadUint32Le(destructor) + if !ok { + panic("sqlite3: could not read malloc_destructor global") + } + + return &Conn{ + module: module, + memory: module.Memory(), + api: sqliteAPI{ + malloc: getFun("malloc"), + free: getFun("free"), + destructor: uint64(destructor), + errstr: getFun("sqlite3_errstr"), + errmsg: getFun("sqlite3_errmsg"), + erroff: getFun("sqlite3_error_offset"), + open: getFun("sqlite3_open_v2"), + close: getFun("sqlite3_close"), + prepare: getFun("sqlite3_prepare_v3"), + finalize: getFun("sqlite3_finalize"), + exec: getFun("sqlite3_exec"), + step: getFun("sqlite3_step"), + bindInteger: getFun("sqlite3_bind_int64"), + bindFloat: getFun("sqlite3_bind_double"), + bindText: getFun("sqlite3_bind_text64"), + bindBlob: getFun("sqlite3_bind_blob64"), + bindZeroBlob: getFun("sqlite3_bind_zeroblob64"), + bindNull: getFun("sqlite3_bind_null"), + columnInteger: getFun("sqlite3_column_int64"), + columnFloat: getFun("sqlite3_column_double"), + columnText: getFun("sqlite3_column_text"), + columnBlob: getFun("sqlite3_column_blob"), + columnBytes: getFun("sqlite3_column_bytes"), + columnType: getFun("sqlite3_column_type"), + }, + } +} + +type sqliteAPI struct { + malloc api.Function + free api.Function + destructor uint64 + errstr api.Function + errmsg api.Function + erroff api.Function + open api.Function + close api.Function + prepare api.Function + finalize api.Function + exec api.Function + step api.Function + bindInteger api.Function + bindFloat api.Function + bindText api.Function + bindBlob api.Function + bindZeroBlob api.Function + bindNull api.Function + columnInteger api.Function + columnFloat api.Function + columnText api.Function + columnBlob api.Function + columnBytes api.Function + columnType api.Function +} diff --git a/build_deps.sh b/build_deps.sh index e4a3510..976698b 100755 --- a/build_deps.sh +++ b/build_deps.sh @@ -32,17 +32,27 @@ zig cc --target=wasm32-wasi -flto -g0 -O2 \ -DSQLITE_OMIT_SHARED_CACHE \ -DSQLITE_OMIT_AUTOINIT \ -DSQLITE_OMIT_UTF16 \ + -Wl,--export=malloc \ + -Wl,--export=free \ + -Wl,--export=malloc_destructor \ + -Wl,--export=sqlite3_errstr \ + -Wl,--export=sqlite3_errmsg \ + -Wl,--export=sqlite3_error_offset \ -Wl,--export=sqlite3_open_v2 \ -Wl,--export=sqlite3_close \ -Wl,--export=sqlite3_prepare_v3 \ -Wl,--export=sqlite3_finalize \ -Wl,--export=sqlite3_exec \ -Wl,--export=sqlite3_step \ - -Wl,--export=sqlite3_column_text \ + -Wl,--export=sqlite3_bind_int64 \ + -Wl,--export=sqlite3_bind_double \ + -Wl,--export=sqlite3_bind_text64 \ + -Wl,--export=sqlite3_bind_blob64 \ + -Wl,--export=sqlite3_bind_zeroblob64 \ + -Wl,--export=sqlite3_bind_null \ -Wl,--export=sqlite3_column_int64 \ -Wl,--export=sqlite3_column_double \ - -Wl,--export=sqlite3_errstr \ - -Wl,--export=sqlite3_errmsg \ - -Wl,--export=sqlite3_error_offset \ - -Wl,--export=malloc \ - -Wl,--export=free + -Wl,--export=sqlite3_column_text \ + -Wl,--export=sqlite3_column_blob \ + -Wl,--export=sqlite3_column_bytes \ + -Wl,--export=sqlite3_column_type \ diff --git a/conn.go b/conn.go index 08c9a0d..d079e4e 100644 --- a/conn.go +++ b/conn.go @@ -47,36 +47,17 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) { if err != nil { return nil, err } - - c := Conn{ - module: module, - memory: module.Memory(), - api: sqliteAPI{ - malloc: module.ExportedFunction("malloc"), - free: module.ExportedFunction("free"), - errstr: module.ExportedFunction("sqlite3_errstr"), - errmsg: module.ExportedFunction("sqlite3_errmsg"), - erroff: module.ExportedFunction("sqlite3_error_offset"), - open: module.ExportedFunction("sqlite3_open_v2"), - close: module.ExportedFunction("sqlite3_close"), - prepare: module.ExportedFunction("sqlite3_prepare_v3"), - finalize: module.ExportedFunction("sqlite3_finalize"), - exec: module.ExportedFunction("sqlite3_exec"), - step: module.ExportedFunction("sqlite3_step"), - columnText: module.ExportedFunction("sqlite3_column_text"), - columnInt: module.ExportedFunction("sqlite3_column_int64"), - columnFloat: module.ExportedFunction("sqlite3_column_double"), - }, - } - defer func() { if conn == nil { - c.Close() + module.Close(ctx) } }() + c := newConn(module) namePtr := c.newString(name) - connPtr := c.newBytes(ptrSize) + connPtr := c.new(ptrSize) + defer c.free(namePtr) + defer c.free(connPtr) r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0) if err != nil { @@ -84,13 +65,11 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) { } c.handle, _ = c.memory.ReadUint32Le(connPtr) - c.free(connPtr) - c.free(namePtr) - if r[0] != _OK { - return nil, c.error(r[0]) + if err := c.error(r[0]); err != nil { + return nil, err } - return &c, nil + return c, nil } func (c *Conn) Close() error { @@ -99,26 +78,21 @@ func (c *Conn) Close() error { return err } - if r[0] != _OK { - return c.error(r[0]) + if err := c.error(r[0]); err != nil { + return err } return c.module.Close(context.TODO()) } func (c *Conn) Exec(sql string) error { sqlPtr := c.newString(sql) + defer c.free(sqlPtr) r, err := c.api.exec.Call(context.TODO(), uint64(c.handle), uint64(sqlPtr), 0, 0, 0) if err != nil { return err } - - c.free(sqlPtr) - - if r[0] != _OK { - return c.error(r[0]) - } - return nil + return c.error(r[0]) } func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { @@ -127,8 +101,11 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) { sqlPtr := c.newString(sql) - stmtPtr := c.newBytes(ptrSize) - tailPtr := c.newBytes(ptrSize) + stmtPtr := c.new(ptrSize) + tailPtr := c.new(ptrSize) + defer c.free(sqlPtr) + defer c.free(stmtPtr) + defer c.free(tailPtr) r, err := c.api.prepare.Call(context.TODO(), uint64(c.handle), uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), @@ -142,12 +119,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str i, _ := c.memory.ReadUint32Le(tailPtr) tail = sql[i-sqlPtr:] - c.free(tailPtr) - c.free(stmtPtr) - c.free(sqlPtr) - - if r[0] != _OK { - return nil, "", c.error(r[0]) + if err := c.error(r[0]); err != nil { + return nil, "", err } if stmt.handle == 0 { return nil, "", nil @@ -155,7 +128,11 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str return } -func (c *Conn) error(rc uint64) *Error { +func (c *Conn) error(rc uint64) error { + if rc == _OK { + return nil + } + serr := Error{ Code: ErrorCode(rc & 0xFF), ExtendedCode: ExtendedErrorCode(rc), @@ -192,7 +169,7 @@ func (c *Conn) free(ptr uint32) { } } -func (c *Conn) newBytes(len uint32) uint32 { +func (c *Conn) new(len uint32) uint32 { r, err := c.api.malloc.Call(context.TODO(), uint64(len)) if err != nil { panic(err) @@ -203,17 +180,30 @@ func (c *Conn) newBytes(len uint32) uint32 { return uint32(r[0]) } -func (c *Conn) newString(str string) uint32 { - ptr := c.newBytes(uint32(len(str) + 1)) +func (c *Conn) newBytes(s []byte) uint32 { + ptr := c.new(uint32(len(s))) - buf, ok := c.memory.Read(ptr, uint32(len(str)+1)) + buf, ok := c.memory.Read(ptr, uint32(len(s))) + if !ok { + c.api.free.Call(context.TODO(), uint64(ptr)) + panic("sqlite3: failed to init bytes") + } + + copy(buf, s) + return ptr +} + +func (c *Conn) newString(s string) uint32 { + ptr := c.new(uint32(len(s) + 1)) + + buf, ok := c.memory.Read(ptr, uint32(len(s)+1)) if !ok { c.api.free.Call(context.TODO(), uint64(ptr)) panic("sqlite3: failed to init string") } - buf[len(str)] = 0 - copy(buf, str) + buf[len(s)] = 0 + copy(buf, s) return ptr } @@ -235,20 +225,3 @@ func (c *Conn) getString(ptr, maxlen uint32) string { } const ptrSize = 4 - -type sqliteAPI struct { - malloc api.Function - free api.Function - errstr api.Function - errmsg api.Function - erroff api.Function - open api.Function - close api.Function - prepare api.Function - finalize api.Function - exec api.Function - step api.Function - columnInt api.Function - columnText api.Function - columnFloat api.Function -} diff --git a/const.go b/const.go index 09e5727..09cd909 100644 --- a/const.go +++ b/const.go @@ -4,6 +4,8 @@ const ( _OK = 0 /* Successful result */ _ROW = 100 /* sqlite3_step() has another row ready */ _DONE = 101 /* sqlite3_step() has finished executing */ + + _UTF8 = 1 ) type ErrorCode int @@ -151,3 +153,13 @@ const ( PREPARE_NORMALIZE PrepareFlag = 0x02 PREPARE_NO_VTAB PrepareFlag = 0x04 ) + +type Datatype uint + +const ( + INTEGER Datatype = 1 + FLOAT Datatype = 2 + TEXT Datatype = 3 + BLOB Datatype = 4 + NULL Datatype = 5 +) diff --git a/sqlite3/main.c b/sqlite3/main.c index 499f532..03f5424 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -1,3 +1,4 @@ +#include #include "sqlite3.h" int main() { @@ -10,3 +11,5 @@ sqlite3_vfs *sqlite3_demovfs(); int sqlite3_os_init() { return sqlite3_vfs_register(sqlite3_demovfs(), /*default=*/1); } + +sqlite3_destructor_type malloc_destructor = &free; \ No newline at end of file diff --git a/stmt.go b/stmt.go index f311802..f6a6b52 100644 --- a/stmt.go +++ b/stmt.go @@ -1,6 +1,9 @@ package sqlite3 -import "context" +import ( + "context" + "math" +) type Stmt struct { c *Conn @@ -14,8 +17,67 @@ func (s *Stmt) Close() error { } s.handle = 0 - if r[0] != _OK { - return s.c.error(r[0]) - } - return nil + return s.c.error(r[0]) +} + +func (s *Stmt) BindBool(param int, value bool) error { + if value { + return s.BindInt64(param, 1) + } + return s.BindInt64(param, 0) +} + +func (s *Stmt) BindInt(param int, value int) error { + return s.BindInt64(param, int64(value)) +} + +func (s *Stmt) BindInt64(param int, value int64) error { + r, err := s.c.api.bindInteger.Call(context.TODO(), + uint64(s.handle), uint64(param), uint64(value)) + if err != nil { + return err + } + return s.c.error(r[0]) +} + +func (s *Stmt) BindFloat(param int, value float64) error { + r, err := s.c.api.bindFloat.Call(context.TODO(), + uint64(s.handle), uint64(param), math.Float64bits(value)) + if err != nil { + return err + } + return s.c.error(r[0]) +} + +func (s *Stmt) BindText(param int, value string) error { + ptr := s.c.newString(value) + r, err := s.c.api.bindText.Call(context.TODO(), + uint64(s.handle), uint64(param), + uint64(ptr), uint64(len(value)), + s.c.api.destructor, _UTF8) + if err != nil { + return err + } + return s.c.error(r[0]) +} + +func (s *Stmt) BindBlob(param int, value []byte) error { + ptr := s.c.newBytes(value) + r, err := s.c.api.bindBlob.Call(context.TODO(), + uint64(s.handle), uint64(param), + uint64(ptr), uint64(len(value)), + s.c.api.destructor) + if err != nil { + return err + } + return s.c.error(r[0]) +} + +func (s *Stmt) BindNull(param int) error { + r, err := s.c.api.bindNull.Call(context.TODO(), + uint64(s.handle), uint64(param)) + if err != nil { + return err + } + return s.c.error(r[0]) }