Files
sqlite3/conn.go
2023-01-15 04:35:37 +00:00

262 lines
5.4 KiB
Go

package sqlite3
import (
"bytes"
"context"
"io/fs"
"path/filepath"
"strconv"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental/writefs"
)
// Conn is a single database connection backed by its own instance of the
// SQLite WebAssembly module.
type Conn struct {
// handle is the 32-bit connection pointer read back from module memory
// after sqlite3_open_v2; it is passed to every per-connection API call.
handle uint32
// module is the instantiated WebAssembly module owning this connection.
module api.Module
// memory is the module's linear memory, used to exchange strings and
// out-pointers with SQLite.
memory api.Memory
// api holds the exported functions looked up from module.
api sqliteAPI
}
// Open instantiates a fresh copy of the SQLite WebAssembly module and opens
// the database called name on it through sqlite3_open_v2.
//
// Unless name is ":memory:", the directory part of name is mounted as the
// module's file system and only the base name is handed to SQLite. A zero
// flags value defaults to OPEN_READWRITE | OPEN_CREATE. An empty vfs passes
// a zero (NULL) pointer as the VFS name.
//
// On failure the returned connection is nil and everything that was set up
// so far is torn down again.
func Open(name string, flags uint64, vfs string) (conn *Conn, err error) {
	once.Do(compile)

	// Named fsys (not fs) so the imported io/fs package is not shadowed.
	var fsys fs.FS
	if name != ":memory:" {
		dir := filepath.Dir(name)
		name = filepath.Base(name)
		fsys, err = writefs.NewDirFS(dir)
		if err != nil {
			return nil, err
		}
	}

	ctx := context.TODO()
	cfg := wazero.NewModuleConfig().
		WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10))
	if fsys != nil {
		cfg = cfg.WithFS(fsys)
	}

	// The right-hand module is the package-level compiled module.
	mod, err := wasm.InstantiateModule(ctx, module, cfg)
	if err != nil {
		return nil, err
	}

	c := Conn{
		module: mod,
		memory: mod.Memory(),
		api: sqliteAPI{
			malloc:      mod.ExportedFunction("malloc"),
			free:        mod.ExportedFunction("free"),
			errstr:      mod.ExportedFunction("sqlite3_errstr"),
			errmsg:      mod.ExportedFunction("sqlite3_errmsg"),
			erroff:      mod.ExportedFunction("sqlite3_error_offset"),
			errext:      mod.ExportedFunction("sqlite3_extended_errcode"),
			open:        mod.ExportedFunction("sqlite3_open_v2"),
			close:       mod.ExportedFunction("sqlite3_close"),
			prepare:     mod.ExportedFunction("sqlite3_prepare_v3"),
			finalize:    mod.ExportedFunction("sqlite3_finalize"),
			exec:        mod.ExportedFunction("sqlite3_exec"),
			step:        mod.ExportedFunction("sqlite3_step"),
			columnText:  mod.ExportedFunction("sqlite3_column_text"),
			columnInt:   mod.ExportedFunction("sqlite3_column_int64"),
			columnFloat: mod.ExportedFunction("sqlite3_column_double"),
		},
	}

	// Tear everything down on any failure path below.
	defer func() {
		if conn != nil {
			return
		}
		// Close returns before releasing the module when the
		// sqlite3_close call fails, so release the instance here to
		// avoid leaking it. The original error is the one reported.
		if c.Close() != nil {
			_ = mod.Close(ctx)
		}
	}()

	namePtr := c.newString(name)
	connPtr := c.newBytes(ptrSize)
	if flags == 0 {
		flags = OPEN_READWRITE | OPEN_CREATE
	}
	var vfsPtr uint32
	if vfs != "" {
		vfsPtr = c.newString(vfs)
	}

	r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(connPtr), flags, uint64(vfsPtr))
	if err != nil {
		// No need to free the scratch buffers: the deferred cleanup
		// discards the whole module instance.
		return nil, err
	}

	// Read the handle even when the open failed: it is needed both to
	// fetch the error message and to close the connection afterwards.
	c.handle, _ = c.memory.ReadUint32Le(connPtr)

	c.free(connPtr)
	c.free(namePtr)
	c.free(vfsPtr)

	if r[0] != OK {
		return nil, c.error(r[0])
	}
	return &c, nil
}
// Close calls sqlite3_close on the connection handle and, when that
// succeeds, closes the underlying WebAssembly module.
//
// If the call itself fails, or SQLite reports a code other than OK, the
// corresponding error is returned and the module is left open.
func (c *Conn) Close() error {
	ctx := context.TODO()

	res, callErr := c.api.close.Call(ctx, uint64(c.handle))
	switch {
	case callErr != nil:
		return callErr
	case res[0] != OK:
		return c.error(res[0])
	}
	return c.module.Close(ctx)
}
// Exec runs sql through sqlite3_exec with no callback and returns an error
// when either the call fails or SQLite reports a code other than OK.
//
// The SQL text is copied into module memory for the duration of the call
// and released again on every path.
func (c *Conn) Exec(sql string) error {
	ctx := context.TODO()
	sqlPtr := c.newString(sql)

	r, err := c.api.exec.Call(ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
	if err != nil {
		// Best-effort release so the buffer is not leaked; the call
		// error is what the caller needs to see, so a failure to free
		// is deliberately ignored here rather than panicking.
		_, _ = c.api.free.Call(ctx, uint64(sqlPtr))
		return err
	}

	c.free(sqlPtr)
	if r[0] != OK {
		return c.error(r[0])
	}
	return nil
}
// Prepare compiles the first statement in sql with sqlite3_prepare_v3.
//
// It returns the prepared statement and tail, the part of sql that starts
// at the position SQLite reports through its tail out-pointer. When SQLite
// reports success but produces no statement handle, all results are zero.
// The args parameter is accepted but not used by this function.
//
// An error is returned when the call itself fails or SQLite reports a code
// other than OK.
func (c *Conn) Prepare(sql string, flags uint64, args ...any) (stmt *Stmt, tail string, err error) {
	ctx := context.TODO()

	sqlPtr := c.newString(sql)
	stmtPtr := c.newBytes(ptrSize)
	tailPtr := c.newBytes(ptrSize)

	r, err := c.api.prepare.Call(ctx, uint64(c.handle),
		uint64(sqlPtr), uint64(len(sql)+1), flags,
		uint64(stmtPtr), uint64(tailPtr))
	if err != nil {
		// Best-effort release so the scratch buffers are not leaked;
		// the call error is what gets reported.
		for _, ptr := range [...]uint32{tailPtr, stmtPtr, sqlPtr} {
			_, _ = c.api.free.Call(ctx, uint64(ptr))
		}
		return nil, "", err
	}

	// Copy the out-parameters before their buffers are released.
	handle, _ := c.memory.ReadUint32Le(stmtPtr)
	tailAddr, _ := c.memory.ReadUint32Le(tailPtr)

	c.free(tailPtr)
	c.free(stmtPtr)
	c.free(sqlPtr)

	if r[0] != OK {
		// The tail out-pointer is not trusted on failure, so it is
		// never used to slice sql here.
		return nil, "", c.error(r[0])
	}
	if handle == 0 {
		return nil, "", nil
	}

	// The tail address points into the copy of sql that lived at sqlPtr.
	// Only slice when the resulting offset really falls inside sql, so a
	// bogus address cannot cause an out-of-range panic.
	if tailAddr >= sqlPtr {
		if off := tailAddr - sqlPtr; uint64(off) <= uint64(len(sql)) {
			tail = sql[off:]
		}
	}
	return &Stmt{c: c, handle: handle}, tail, nil
}
// error builds an *Error for the result code rc.
//
// It asks the module for the generic description of rc (sqlite3_errstr),
// the connection's current message (sqlite3_errmsg) and its extended code
// (sqlite3_extended_errcode). Each lookup is best-effort: a failed call
// simply leaves the corresponding field at its zero value. A message that
// merely repeats the generic description is dropped.
func (c *Conn) error(rc uint64) *Error {
	ctx := context.TODO()
	conn := uint64(c.handle)
	e := &Error{Code: int(rc)}

	// Generic description of the result code.
	if res, _ := c.api.errstr.Call(ctx, rc); res != nil {
		e.str = c.getString(uint32(res[0]), 512)
	}

	// Connection-specific message.
	if res, _ := c.api.errmsg.Call(ctx, conn); res != nil {
		e.msg = c.getString(uint32(res[0]), 512)
	}

	// Extended result code.
	if res, _ := c.api.errext.Call(ctx, conn); res != nil {
		e.ExtendedCode = int(res[0])
	}

	if e.msg == e.str {
		e.msg = ""
	}
	return e
}
// free releases the module memory at ptr using the module's free export.
// A zero ptr is ignored. It panics if the call into the module fails.
func (c *Conn) free(ptr uint32) {
	if ptr != 0 {
		if _, err := c.api.free.Call(context.TODO(), uint64(ptr)); err != nil {
			panic(err)
		}
	}
}
// newBytes allocates len bytes of module memory using the module's malloc
// export and returns the address of the allocation.
//
// It panics if the call into the module fails or if malloc returns zero.
func (c *Conn) newBytes(len uint32) uint32 {
	res, err := c.api.malloc.Call(context.TODO(), uint64(len))
	switch {
	case err != nil:
		panic(err)
	case res[0] == 0:
		panic("sqlite3: out of memory")
	}
	return uint32(res[0])
}
// newString copies str into freshly allocated module memory, followed by a
// NUL terminator, and returns the address of the copy.
//
// It panics if the allocated range cannot be viewed through the module's
// memory; the allocation is released (best-effort) before panicking.
func (c *Conn) newString(str string) uint32 {
	size := uint32(len(str) + 1) // room for the trailing NUL
	ptr := c.newBytes(size)

	if buf, ok := c.memory.Read(ptr, size); ok {
		buf[len(str)] = 0
		copy(buf, str)
		return ptr
	}

	c.api.free.Call(context.TODO(), uint64(ptr))
	panic("sqlite3: failed to init string")
}
// getString reads the NUL-terminated string stored at ptr in module memory.
//
// At most maxlen bytes are examined. When that window runs past the end of
// memory, the read is retried with everything from ptr to the end of
// memory. It panics if ptr lies outside memory or if no NUL terminator is
// found in the examined bytes.
func (c *Conn) getString(ptr, maxlen uint32) string {
	view, ok := c.memory.Read(ptr, maxlen)
	if !ok {
		// The full window did not fit: clamp it to the end of memory.
		size := c.memory.Size()
		if ptr < size {
			view, ok = c.memory.Read(ptr, size-ptr)
		}
		if !ok {
			panic("sqlite3: invalid pointer")
		}
	}

	end := bytes.IndexByte(view, 0)
	if end < 0 {
		panic("sqlite3: missing NUL terminator")
	}
	return string(view[:end])
}
// ptrSize is the number of bytes allocated for each out-pointer handed to
// SQLite; the values are read back as little-endian 32-bit integers.
const ptrSize = 4
// sqliteAPI groups the functions exported by the SQLite WebAssembly module
// that this package calls. Open fills in every field by looking up the
// export of the corresponding name.
type sqliteAPI struct {
// Memory management: "malloc" and "free".
malloc api.Function
free api.Function
// Error reporting: "sqlite3_errstr", "sqlite3_errmsg",
// "sqlite3_extended_errcode" and "sqlite3_error_offset".
errstr api.Function
errmsg api.Function
errext api.Function
erroff api.Function
// Connection lifecycle: "sqlite3_open_v2" and "sqlite3_close".
open api.Function
close api.Function
// Statements: "sqlite3_prepare_v3", "sqlite3_finalize", "sqlite3_exec"
// and "sqlite3_step".
prepare api.Function
finalize api.Function
exec api.Function
step api.Function
// Column access: "sqlite3_column_int64", "sqlite3_column_text" and
// "sqlite3_column_double".
columnInt api.Function
columnText api.Function
columnFloat api.Function
}