Bind values.

This commit is contained in:
Nuno Cruces
2023-01-17 13:43:16 +00:00
parent c480512001
commit 469696867b
6 changed files with 222 additions and 81 deletions

81
api.go Normal file
View File

@@ -0,0 +1,81 @@
package sqlite3
import "github.com/tetratelabs/wazero/api"
func newConn(module api.Module) *Conn {
getFun := func(name string) api.Function {
f := module.ExportedFunction(name)
if f == nil {
panic("sqlite3: could not find " + name + " function")
}
return f
}
global := module.ExportedGlobal("malloc_destructor")
if global == nil {
panic("sqlite3: could not find malloc_destructor global")
}
destructor := uint32(global.Get())
destructor, ok := module.Memory().ReadUint32Le(destructor)
if !ok {
panic("sqlite3: could not read malloc_destructor global")
}
return &Conn{
module: module,
memory: module.Memory(),
api: sqliteAPI{
malloc: getFun("malloc"),
free: getFun("free"),
destructor: uint64(destructor),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
exec: getFun("sqlite3_exec"),
step: getFun("sqlite3_step"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindNull: getFun("sqlite3_bind_null"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
columnType: getFun("sqlite3_column_type"),
},
}
}
type sqliteAPI struct {
malloc api.Function
free api.Function
destructor uint64
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
prepare api.Function
finalize api.Function
exec api.Function
step api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindNull api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
columnType api.Function
}

View File

@@ -32,17 +32,27 @@ zig cc --target=wasm32-wasi -flto -g0 -O2 \
-DSQLITE_OMIT_SHARED_CACHE \
-DSQLITE_OMIT_AUTOINIT \
-DSQLITE_OMIT_UTF16 \
-Wl,--export=malloc \
-Wl,--export=free \
-Wl,--export=malloc_destructor \
-Wl,--export=sqlite3_errstr \
-Wl,--export=sqlite3_errmsg \
-Wl,--export=sqlite3_error_offset \
-Wl,--export=sqlite3_open_v2 \
-Wl,--export=sqlite3_close \
-Wl,--export=sqlite3_prepare_v3 \
-Wl,--export=sqlite3_finalize \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_errstr \
-Wl,--export=sqlite3_errmsg \
-Wl,--export=sqlite3_error_offset \
-Wl,--export=malloc \
-Wl,--export=free
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_column_type \

113
conn.go
View File

@@ -47,36 +47,17 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) {
if err != nil {
return nil, err
}
c := Conn{
module: module,
memory: module.Memory(),
api: sqliteAPI{
malloc: module.ExportedFunction("malloc"),
free: module.ExportedFunction("free"),
errstr: module.ExportedFunction("sqlite3_errstr"),
errmsg: module.ExportedFunction("sqlite3_errmsg"),
erroff: module.ExportedFunction("sqlite3_error_offset"),
open: module.ExportedFunction("sqlite3_open_v2"),
close: module.ExportedFunction("sqlite3_close"),
prepare: module.ExportedFunction("sqlite3_prepare_v3"),
finalize: module.ExportedFunction("sqlite3_finalize"),
exec: module.ExportedFunction("sqlite3_exec"),
step: module.ExportedFunction("sqlite3_step"),
columnText: module.ExportedFunction("sqlite3_column_text"),
columnInt: module.ExportedFunction("sqlite3_column_int64"),
columnFloat: module.ExportedFunction("sqlite3_column_double"),
},
}
defer func() {
if conn == nil {
c.Close()
module.Close(ctx)
}
}()
c := newConn(module)
namePtr := c.newString(name)
connPtr := c.newBytes(ptrSize)
connPtr := c.new(ptrSize)
defer c.free(namePtr)
defer c.free(connPtr)
r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
if err != nil {
@@ -84,13 +65,11 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) {
}
c.handle, _ = c.memory.ReadUint32Le(connPtr)
c.free(connPtr)
c.free(namePtr)
if r[0] != _OK {
return nil, c.error(r[0])
if err := c.error(r[0]); err != nil {
return nil, err
}
return &c, nil
return c, nil
}
func (c *Conn) Close() error {
@@ -99,26 +78,21 @@ func (c *Conn) Close() error {
return err
}
if r[0] != _OK {
return c.error(r[0])
if err := c.error(r[0]); err != nil {
return err
}
return c.module.Close(context.TODO())
}
func (c *Conn) Exec(sql string) error {
sqlPtr := c.newString(sql)
defer c.free(sqlPtr)
r, err := c.api.exec.Call(context.TODO(), uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
return err
}
c.free(sqlPtr)
if r[0] != _OK {
return c.error(r[0])
}
return nil
return c.error(r[0])
}
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
@@ -127,8 +101,11 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
sqlPtr := c.newString(sql)
stmtPtr := c.newBytes(ptrSize)
tailPtr := c.newBytes(ptrSize)
stmtPtr := c.new(ptrSize)
tailPtr := c.new(ptrSize)
defer c.free(sqlPtr)
defer c.free(stmtPtr)
defer c.free(tailPtr)
r, err := c.api.prepare.Call(context.TODO(), uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
@@ -142,12 +119,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
i, _ := c.memory.ReadUint32Le(tailPtr)
tail = sql[i-sqlPtr:]
c.free(tailPtr)
c.free(stmtPtr)
c.free(sqlPtr)
if r[0] != _OK {
return nil, "", c.error(r[0])
if err := c.error(r[0]); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
return nil, "", nil
@@ -155,7 +128,11 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return
}
func (c *Conn) error(rc uint64) *Error {
func (c *Conn) error(rc uint64) error {
if rc == _OK {
return nil
}
serr := Error{
Code: ErrorCode(rc & 0xFF),
ExtendedCode: ExtendedErrorCode(rc),
@@ -192,7 +169,7 @@ func (c *Conn) free(ptr uint32) {
}
}
func (c *Conn) newBytes(len uint32) uint32 {
func (c *Conn) new(len uint32) uint32 {
r, err := c.api.malloc.Call(context.TODO(), uint64(len))
if err != nil {
panic(err)
@@ -203,17 +180,30 @@ func (c *Conn) newBytes(len uint32) uint32 {
return uint32(r[0])
}
func (c *Conn) newString(str string) uint32 {
ptr := c.newBytes(uint32(len(str) + 1))
func (c *Conn) newBytes(s []byte) uint32 {
ptr := c.new(uint32(len(s)))
buf, ok := c.memory.Read(ptr, uint32(len(str)+1))
buf, ok := c.memory.Read(ptr, uint32(len(s)))
if !ok {
c.api.free.Call(context.TODO(), uint64(ptr))
panic("sqlite3: failed to init bytes")
}
copy(buf, s)
return ptr
}
func (c *Conn) newString(s string) uint32 {
ptr := c.new(uint32(len(s) + 1))
buf, ok := c.memory.Read(ptr, uint32(len(s)+1))
if !ok {
c.api.free.Call(context.TODO(), uint64(ptr))
panic("sqlite3: failed to init string")
}
buf[len(str)] = 0
copy(buf, str)
buf[len(s)] = 0
copy(buf, s)
return ptr
}
@@ -235,20 +225,3 @@ func (c *Conn) getString(ptr, maxlen uint32) string {
}
const ptrSize = 4
type sqliteAPI struct {
malloc api.Function
free api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
prepare api.Function
finalize api.Function
exec api.Function
step api.Function
columnInt api.Function
columnText api.Function
columnFloat api.Function
}

View File

@@ -4,6 +4,8 @@ const (
_OK = 0 /* Successful result */
_ROW = 100 /* sqlite3_step() has another row ready */
_DONE = 101 /* sqlite3_step() has finished executing */
_UTF8 = 1
)
type ErrorCode int
@@ -151,3 +153,13 @@ const (
PREPARE_NORMALIZE PrepareFlag = 0x02
PREPARE_NO_VTAB PrepareFlag = 0x04
)
type Datatype uint
const (
INTEGER Datatype = 1
FLOAT Datatype = 2
TEXT Datatype = 3
BLOB Datatype = 4
NULL Datatype = 5
)

View File

@@ -1,3 +1,4 @@
#include <stdlib.h>
#include "sqlite3.h"
int main() {
@@ -10,3 +11,5 @@ sqlite3_vfs *sqlite3_demovfs();
int sqlite3_os_init() {
return sqlite3_vfs_register(sqlite3_demovfs(), /*default=*/1);
}
sqlite3_destructor_type malloc_destructor = &free;

72
stmt.go
View File

@@ -1,6 +1,9 @@
package sqlite3
import "context"
import (
"context"
"math"
)
type Stmt struct {
c *Conn
@@ -14,8 +17,67 @@ func (s *Stmt) Close() error {
}
s.handle = 0
if r[0] != _OK {
return s.c.error(r[0])
}
return nil
return s.c.error(r[0])
}
func (s *Stmt) BindBool(param int, value bool) error {
if value {
return s.BindInt64(param, 1)
}
return s.BindInt64(param, 0)
}
func (s *Stmt) BindInt(param int, value int) error {
return s.BindInt64(param, int64(value))
}
func (s *Stmt) BindInt64(param int, value int64) error {
r, err := s.c.api.bindInteger.Call(context.TODO(),
uint64(s.handle), uint64(param), uint64(value))
if err != nil {
return err
}
return s.c.error(r[0])
}
func (s *Stmt) BindFloat(param int, value float64) error {
r, err := s.c.api.bindFloat.Call(context.TODO(),
uint64(s.handle), uint64(param), math.Float64bits(value))
if err != nil {
return err
}
return s.c.error(r[0])
}
func (s *Stmt) BindText(param int, value string) error {
ptr := s.c.newString(value)
r, err := s.c.api.bindText.Call(context.TODO(),
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
if err != nil {
return err
}
return s.c.error(r[0])
}
func (s *Stmt) BindBlob(param int, value []byte) error {
ptr := s.c.newBytes(value)
r, err := s.c.api.bindBlob.Call(context.TODO(),
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
if err != nil {
return err
}
return s.c.error(r[0])
}
func (s *Stmt) BindNull(param int) error {
r, err := s.c.api.bindNull.Call(context.TODO(),
uint64(s.handle), uint64(param))
if err != nil {
return err
}
return s.c.error(r[0])
}