This commit is contained in:
Nuno Cruces
2023-02-24 17:49:16 +00:00
parent a69ab1ebe3
commit c1472a48b0
3 changed files with 58 additions and 137 deletions

61
conn.go
View File

@@ -6,6 +6,8 @@ import (
"math"
"runtime"
"sync"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
@@ -54,10 +56,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
if err != nil {
panic(err)
}
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
@@ -82,11 +81,7 @@ func (c *Conn) Close() error {
c.SetInterrupt(context.Background())
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
return err
}
@@ -104,10 +99,7 @@ func (c *Conn) Exec(sql string) error {
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
panic(err)
}
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r[0])
}
@@ -132,12 +124,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
r := c.call(c.api.prepare, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
panic(err)
}
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
@@ -157,10 +146,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
//
// https://www.sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r, err := c.api.autocommit.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
r := c.call(c.api.autocommit, uint64(c.handle))
return r[0] != 0
}
@@ -169,10 +155,7 @@ func (c *Conn) GetAutocommit() bool {
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
r := c.call(c.api.lastRowid, uint64(c.handle))
return r[0]
}
@@ -182,10 +165,7 @@ func (c *Conn) LastInsertRowID() uint64 {
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
r := c.call(c.api.changes, uint64(c.handle))
return r[0]
}
@@ -270,10 +250,7 @@ func (c *Conn) sendInterrupt() {
defer c.mtx.Unlock()
// This is safe to call from a goroutine
// because it doesn't touch the C stack.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
c.call(c.api.interrupt, uint64(c.handle))
}
// Savepoint creates a named SQLite transaction using SAVEPOINT.
@@ -383,21 +360,23 @@ func (c *Conn) error(rc uint64, sql ...string) error {
return &err
}
func (c *Conn) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(c.ctx, params...)
if err != nil {
panic(err)
}
return r
}
func (c *Conn) free(ptr uint32) {
if ptr == 0 {
return
}
_, err := c.api.free.Call(c.ctx, uint64(ptr))
if err != nil {
panic(err)
}
c.call(c.api.free, uint64(ptr))
}
func (c *Conn) new(size uint32) uint32 {
r, err := c.api.malloc.Call(c.ctx, uint64(size))
if err != nil {
panic(err)
}
r := c.call(c.api.malloc, uint64(size))
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)

View File

@@ -6,6 +6,20 @@ import (
"testing"
)
func TestConn_call_nil(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
db.call(db.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()

120
stmt.go
View File

@@ -24,10 +24,7 @@ func (s *Stmt) Close() error {
return nil
}
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.finalize, uint64(s.handle))
s.handle = 0
return s.c.error(r[0])
@@ -37,10 +34,7 @@ func (s *Stmt) Close() error {
//
// https://www.sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.reset, uint64(s.handle))
s.err = nil
return s.c.error(r[0])
}
@@ -49,10 +43,7 @@ func (s *Stmt) Reset() error {
//
// https://www.sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r[0])
}
@@ -67,10 +58,7 @@ func (s *Stmt) ClearBindings() error {
// https://www.sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.step, uint64(s.handle))
if r[0] == _ROW {
return true
}
@@ -102,11 +90,8 @@ func (s *Stmt) Exec() error {
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r, err := s.c.api.bindCount.Call(s.c.ctx,
r := s.c.call(s.c.api.bindCount,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -117,11 +102,8 @@ func (s *Stmt) BindCount() int {
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.reset()
namePtr := s.c.arena.string(name)
r, err := s.c.api.bindIndex.Call(s.c.ctx,
r := s.c.call(s.c.api.bindIndex,
uint64(s.handle), uint64(namePtr))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -130,11 +112,8 @@ func (s *Stmt) BindIndex(name string) int {
//
// https://www.sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r, err := s.c.api.bindName.Call(s.c.ctx,
r := s.c.call(s.c.api.bindName,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
@@ -169,11 +148,8 @@ func (s *Stmt) BindInt(param int, value int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
r, err := s.c.api.bindInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.bindInteger,
uint64(s.handle), uint64(param), uint64(value))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -182,11 +158,8 @@ func (s *Stmt) BindInt64(param int, value int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
r, err := s.c.api.bindFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.bindFloat,
uint64(s.handle), uint64(param), math.Float64bits(value))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -196,13 +169,10 @@ func (s *Stmt) BindFloat(param int, value float64) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindText(param int, value string) error {
ptr := s.c.newString(value)
r, err := s.c.api.bindText.Call(s.c.ctx,
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -213,13 +183,10 @@ func (s *Stmt) BindText(param int, value string) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
ptr := s.c.newBytes(value)
r, err := s.c.api.bindBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -228,11 +195,8 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r, err := s.c.api.bindZeroBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindZeroBlob,
uint64(s.handle), uint64(param), uint64(n))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -241,11 +205,8 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
r, err := s.c.api.bindNull.Call(s.c.ctx,
r := s.c.call(s.c.api.bindNull,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -271,11 +232,8 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r, err := s.c.api.columnCount.Call(s.c.ctx,
r := s.c.call(s.c.api.columnCount,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -284,11 +242,8 @@ func (s *Stmt) ColumnCount() int {
//
// https://www.sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r, err := s.c.api.columnName.Call(s.c.ctx,
r := s.c.call(s.c.api.columnName,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
@@ -302,11 +257,8 @@ func (s *Stmt) ColumnName(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
r, err := s.c.api.columnType.Call(s.c.ctx,
r := s.c.call(s.c.api.columnType,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return Datatype(r[0])
}
@@ -337,11 +289,8 @@ func (s *Stmt) ColumnInt(col int) int {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
r, err := s.c.api.columnInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.columnInteger,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return int64(r[0])
}
@@ -350,11 +299,8 @@ func (s *Stmt) ColumnInt64(col int) int64 {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
r, err := s.c.api.columnFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.columnFloat,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return math.Float64frombits(r[0])
}
@@ -388,27 +334,18 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r, err := s.c.api.columnText.Call(s.c.ctx,
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.handle))
s.err = s.c.error(r[0])
return ""
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
return string(mem)
@@ -420,27 +357,18 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
r, err := s.c.api.columnBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.handle))
s.err = s.c.error(r[0])
return buf[0:0]
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
return append(buf[0:0], mem...)