Files
sqlite3/conn.go

233 lines
4.4 KiB
Go
Raw Normal View History

2023-01-12 05:57:09 +00:00
package sqlite3
2023-01-11 14:58:20 +00:00
import (
"bytes"
"context"
2023-01-12 13:43:35 +00:00
"io/fs"
"path/filepath"
2023-01-12 05:57:09 +00:00
"strconv"
2023-01-11 14:58:20 +00:00
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
2023-01-12 13:43:35 +00:00
"github.com/tetratelabs/wazero/experimental/writefs"
2023-01-11 14:58:20 +00:00
)
2023-01-12 05:57:09 +00:00
type Conn struct {
handle uint32
module api.Module
memory api.Memory
api sqliteAPI
2023-01-11 14:58:20 +00:00
}
2023-01-16 12:54:24 +00:00
// Open opens the named database for reading and writing, asking SQLite to
// create it if needed. It is OpenFlags with OPEN_READWRITE|OPEN_CREATE.
func Open(name string) (conn *Conn, err error) {
	flags := OPEN_READWRITE | OPEN_CREATE
	return OpenFlags(name, flags)
}
func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) {
2023-01-12 05:57:09 +00:00
once.Do(compile)
2023-01-12 13:43:35 +00:00
var fs fs.FS
if name != ":memory:" {
dir := filepath.Dir(name)
name = filepath.Base(name)
fs, err = writefs.NewDirFS(dir)
if err != nil {
return nil, err
}
}
2023-01-12 05:57:09 +00:00
ctx := context.TODO()
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10))
2023-01-12 13:43:35 +00:00
if fs != nil {
cfg = cfg.WithFS(fs)
}
2023-01-12 05:57:09 +00:00
module, err := wasm.InstantiateModule(ctx, module, cfg)
2023-01-11 14:58:20 +00:00
if err != nil {
2023-01-12 05:57:09 +00:00
return nil, err
2023-01-11 14:58:20 +00:00
}
2023-01-12 13:43:35 +00:00
defer func() {
if conn == nil {
2023-01-17 13:43:16 +00:00
module.Close(ctx)
2023-01-12 13:43:35 +00:00
}
}()
2023-01-12 05:57:09 +00:00
2023-01-17 13:43:16 +00:00
c := newConn(module)
2023-01-12 13:43:35 +00:00
namePtr := c.newString(name)
2023-01-17 13:43:16 +00:00
connPtr := c.new(ptrSize)
defer c.free(namePtr)
defer c.free(connPtr)
2023-01-12 13:43:35 +00:00
2023-01-16 12:54:24 +00:00
r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
2023-01-11 14:58:20 +00:00
if err != nil {
2023-01-12 05:57:09 +00:00
return nil, err
2023-01-11 14:58:20 +00:00
}
2023-01-15 04:35:37 +00:00
c.handle, _ = c.memory.ReadUint32Le(connPtr)
2023-01-11 14:58:20 +00:00
2023-01-17 13:43:16 +00:00
if err := c.error(r[0]); err != nil {
return nil, err
2023-01-12 05:57:09 +00:00
}
2023-01-17 13:43:16 +00:00
return c, nil
2023-01-12 05:57:09 +00:00
}
2023-01-12 13:43:35 +00:00
func (c *Conn) Close() error {
r, err := c.api.close.Call(context.TODO(), uint64(c.handle))
2023-01-12 05:57:09 +00:00
if err != nil {
2023-01-11 14:58:20 +00:00
return err
}
2023-01-12 13:43:35 +00:00
2023-01-17 13:43:16 +00:00
if err := c.error(r[0]); err != nil {
return err
2023-01-12 13:43:35 +00:00
}
return c.module.Close(context.TODO())
2023-01-11 14:58:20 +00:00
}
2023-01-12 13:43:35 +00:00
func (c *Conn) Exec(sql string) error {
sqlPtr := c.newString(sql)
2023-01-17 13:43:16 +00:00
defer c.free(sqlPtr)
2023-01-12 13:43:35 +00:00
r, err := c.api.exec.Call(context.TODO(), uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
2023-01-11 14:58:20 +00:00
if err != nil {
return err
}
2023-01-17 13:43:16 +00:00
return c.error(r[0])
2023-01-11 14:58:20 +00:00
}
2023-01-16 12:54:24 +00:00
// Prepare compiles the first statement in sql without any prepare flags.
// It is PrepareFlags with a zero PrepareFlag.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
	var noFlags PrepareFlag
	return c.PrepareFlags(sql, noFlags)
}
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
2023-01-15 04:35:37 +00:00
sqlPtr := c.newString(sql)
2023-01-17 13:43:16 +00:00
stmtPtr := c.new(ptrSize)
tailPtr := c.new(ptrSize)
defer c.free(sqlPtr)
defer c.free(stmtPtr)
defer c.free(tailPtr)
2023-01-15 04:35:37 +00:00
r, err := c.api.prepare.Call(context.TODO(), uint64(c.handle),
2023-01-16 12:54:24 +00:00
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
2023-01-15 04:35:37 +00:00
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
return nil, "", err
}
stmt = &Stmt{c: c}
stmt.handle, _ = c.memory.ReadUint32Le(stmtPtr)
i, _ := c.memory.ReadUint32Le(tailPtr)
tail = sql[i-sqlPtr:]
2023-01-17 13:43:16 +00:00
if err := c.error(r[0]); err != nil {
return nil, "", err
2023-01-15 04:35:37 +00:00
}
if stmt.handle == 0 {
return nil, "", nil
}
return
}
2023-01-17 13:43:16 +00:00
func (c *Conn) error(rc uint64) error {
if rc == _OK {
return nil
}
2023-01-16 12:54:24 +00:00
serr := Error{
Code: ErrorCode(rc & 0xFF),
ExtendedCode: ExtendedErrorCode(rc),
}
2023-01-12 13:43:35 +00:00
var r []uint64
// string
2023-01-16 12:54:24 +00:00
r, _ = c.api.errstr.Call(context.TODO(), rc)
2023-01-12 13:43:35 +00:00
if r != nil {
serr.str = c.getString(uint32(r[0]), 512)
}
// message
r, _ = c.api.errmsg.Call(context.TODO(), uint64(c.handle))
if r != nil {
serr.msg = c.getString(uint32(r[0]), 512)
}
2023-01-17 15:01:30 +00:00
switch serr.msg {
case "not an error", serr.str:
2023-01-12 13:43:35 +00:00
serr.msg = ""
}
return &serr
}
2023-01-12 05:57:09 +00:00
func (c *Conn) free(ptr uint32) {
2023-01-12 13:43:35 +00:00
if ptr == 0 {
return
}
2023-01-12 05:57:09 +00:00
_, err := c.api.free.Call(context.TODO(), uint64(ptr))
2023-01-11 14:58:20 +00:00
if err != nil {
panic(err)
}
}
2023-01-17 13:43:16 +00:00
func (c *Conn) new(len uint32) uint32 {
2023-01-12 11:06:17 +00:00
r, err := c.api.malloc.Call(context.TODO(), uint64(len))
2023-01-11 14:58:20 +00:00
if err != nil {
panic(err)
}
2023-01-12 11:06:17 +00:00
if r[0] == 0 {
panic("sqlite3: out of memory")
}
2023-01-11 14:58:20 +00:00
return uint32(r[0])
}
2023-01-17 13:43:16 +00:00
func (c *Conn) newBytes(s []byte) uint32 {
2023-01-17 18:31:46 +00:00
if s == nil {
return 0
}
2023-01-11 14:58:20 +00:00
2023-01-17 18:31:46 +00:00
siz := uint32(len(s))
ptr := c.new(siz)
mem, ok := c.memory.Read(ptr, siz)
2023-01-17 13:43:16 +00:00
if !ok {
c.api.free.Call(context.TODO(), uint64(ptr))
2023-01-17 18:31:46 +00:00
panic("sqlite3: out of range")
2023-01-17 13:43:16 +00:00
}
2023-01-17 18:31:46 +00:00
copy(mem, s)
2023-01-17 13:43:16 +00:00
return ptr
}
func (c *Conn) newString(s string) uint32 {
2023-01-17 18:31:46 +00:00
siz := uint32(len(s) + 1)
ptr := c.new(siz)
mem, ok := c.memory.Read(ptr, siz)
2023-01-12 11:06:17 +00:00
if !ok {
c.api.free.Call(context.TODO(), uint64(ptr))
2023-01-17 18:31:46 +00:00
panic("sqlite3: out of range")
2023-01-11 14:58:20 +00:00
}
2023-01-12 11:06:17 +00:00
2023-01-17 18:31:46 +00:00
mem[len(s)] = 0
copy(mem, s)
2023-01-11 14:58:20 +00:00
return ptr
}
2023-01-12 11:06:17 +00:00
func (c *Conn) getString(ptr, maxlen uint32) string {
2023-01-17 18:31:46 +00:00
mem, ok := c.memory.Read(ptr, maxlen)
2023-01-11 14:58:20 +00:00
if !ok {
2023-01-12 11:06:17 +00:00
if size := c.memory.Size(); ptr < size {
2023-01-17 18:31:46 +00:00
mem, ok = c.memory.Read(ptr, size-ptr)
2023-01-12 11:06:17 +00:00
}
if !ok {
2023-01-17 18:31:46 +00:00
panic("sqlite3: out of range")
2023-01-12 11:06:17 +00:00
}
2023-01-11 14:58:20 +00:00
}
2023-01-17 18:31:46 +00:00
if i := bytes.IndexByte(mem, 0); i < 0 {
2023-01-12 11:06:17 +00:00
panic("sqlite3: missing NUL terminator")
2023-01-11 14:58:20 +00:00
} else {
2023-01-17 18:31:46 +00:00
return string(mem[:i])
2023-01-11 14:58:20 +00:00
}
}
2023-01-15 04:35:37 +00:00
// ptrSize is the size in bytes of a module-side pointer slot. Addresses are
// handled as uint32 throughout this file, i.e. 32-bit pointers.
const ptrSize = 32 / 8