This commit is contained in:
Nuno Cruces
2023-02-10 16:42:49 +00:00
parent eaf7cf57fd
commit a33b8d3459
10 changed files with 287 additions and 28 deletions

31
api.go
View File

@@ -1,28 +1,37 @@
package sqlite3 package sqlite3
import "github.com/tetratelabs/wazero/api" import (
"context"
func newConn(module api.Module) *Conn { "github.com/tetratelabs/wazero/api"
)
func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
getFun := func(name string) api.Function { getFun := func(name string) api.Function {
f := module.ExportedFunction(name) f := module.ExportedFunction(name)
if f == nil { if f == nil {
panic(noFuncErr + errorString(name)) err = noFuncErr + errorString(name)
return nil
} }
return f return f
} }
global := module.ExportedGlobal("malloc_destructor") getPtr := func(name string) uint32 {
if global == nil { global := module.ExportedGlobal(name)
panic(noGlobalErr + "malloc_destructor") if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return memory{module}.readUint32(uint32(global.Get()))
} }
destructor := memory{module}.readUint32(uint32(global.Get()))
return &Conn{ c := Conn{
ctx: ctx,
mem: memory{module}, mem: memory{module},
api: sqliteAPI{ api: sqliteAPI{
malloc: getFun("malloc"), malloc: getFun("malloc"),
free: getFun("free"), free: getFun("free"),
destructor: uint64(destructor), destructor: uint64(getPtr("malloc_destructor")),
errcode: getFun("sqlite3_errcode"), errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"), errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"), errmsg: getFun("sqlite3_errmsg"),
@@ -49,6 +58,10 @@ func newConn(module api.Module) *Conn {
columnType: getFun("sqlite3_column_type"), columnType: getFun("sqlite3_column_type"),
}, },
} }
if err != nil {
return nil, err
}
return &c, nil
} }
type sqliteAPI struct { type sqliteAPI struct {

22
conn.go
View File

@@ -2,6 +2,7 @@ package sqlite3
import ( import (
"context" "context"
"math"
) )
// Conn is a database connection handle. // Conn is a database connection handle.
@@ -34,8 +35,11 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
} }
}() }()
c := newConn(module) c, err := newConn(ctx, module)
c.ctx = context.Background() if err != nil {
return nil, err
}
namePtr := c.newString(filename) namePtr := c.newString(filename)
connPtr := c.new(ptrlen) connPtr := c.new(ptrlen)
defer c.free(namePtr) defer c.free(namePtr)
@@ -124,7 +128,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
i := c.mem.readUint32(tailPtr) i := c.mem.readUint32(tailPtr)
tail = sql[i-sqlPtr:] tail = sql[i-sqlPtr:]
if err := c.error(r[0]); err != nil { if err := c.error(r[0], sql); err != nil {
return nil, "", err return nil, "", err
} }
if stmt.handle == 0 { if stmt.handle == 0 {
@@ -133,7 +137,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return return
} }
func (c *Conn) error(rc uint64) error { func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK { if rc == _OK {
return nil return nil
} }
@@ -146,12 +150,20 @@ func (c *Conn) error(rc uint64) error {
var r []uint64 var r []uint64
// Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code. // sqlite3_errmsg is guaranteed to never change the value of the error code.
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil { if r != nil {
err.msg = c.mem.readString(uint32(r[0]), 512) err.msg = c.mem.readString(uint32(r[0]), 512)
} }
if sql != nil {
// sqlite3_error_offset is guaranteed to never change the value of the error code.
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
r, _ = c.api.errstr.Call(c.ctx, rc) r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil { if r != nil {
err.str = c.mem.readString(uint32(r[0]), 512) err.str = c.mem.readString(uint32(r[0]), 512)

View File

@@ -2,10 +2,104 @@ package sqlite3
import ( import (
"bytes" "bytes"
"errors"
"math" "math"
"testing" "testing"
) )
func TestConn_Close(t *testing.T) {
var conn *Conn
conn.Close()
}
func TestConn_Close_BUSY(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare("BEGIN")
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
var serr *Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
}
}
func TestConn_Prepare_Empty(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare("")
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_Invalid(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *Error
_, _, err = db.Prepare("SELECT")
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, _, err = db.Prepare("SELECT * FRM sqlite_schema")
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
}
}
func TestConn_new(t *testing.T) { func TestConn_new(t *testing.T) {
db, err := Open(":memory:") db, err := Open(":memory:")
if err != nil { if err != nil {
@@ -15,7 +109,7 @@ func TestConn_new(t *testing.T) {
defer func() { _ = recover() }() defer func() { _ = recover() }()
db.new(math.MaxUint32) db.new(math.MaxUint32)
t.Error("should have panicked") t.Error("want panic")
} }
func TestConn_newBytes(t *testing.T) { func TestConn_newBytes(t *testing.T) {
@@ -27,7 +121,7 @@ func TestConn_newBytes(t *testing.T) {
ptr := db.newBytes(nil) ptr := db.newBytes(nil)
if ptr != 0 { if ptr != 0 {
t.Errorf("got %x, want nullptr", ptr) t.Errorf("got %#x, want nullptr", ptr)
} }
buf := []byte("sqlite3") buf := []byte("sqlite3")
@@ -95,13 +189,13 @@ func TestConn_getString(t *testing.T) {
func() { func() {
defer func() { _ = recover() }() defer func() { _ = recover() }()
db.mem.readString(ptr, uint32(len(want)/2)) db.mem.readString(ptr, uint32(len(want)/2))
t.Error("should have panicked") t.Error("want panic")
}() }()
func() { func() {
defer func() { _ = recover() }() defer func() { _ = recover() }()
db.mem.readString(0, math.MaxUint32) db.mem.readString(0, math.MaxUint32)
t.Error("should have panicked") t.Error("want panic")
}() }()
} }
@@ -121,8 +215,3 @@ func TestConn_free(t *testing.T) {
db.free(ptr) db.free(ptr)
} }
func TestConn_Close(t *testing.T) {
var conn *Conn
conn.Close()
}

View File

@@ -13,6 +13,7 @@ type Error struct {
code uint64 code uint64
str string str string
msg string msg string
sql string
} }
// Code returns the primary error code for this error. // Code returns the primary error code for this error.
@@ -49,6 +50,11 @@ func (e *Error) Error() string {
return b.String() return b.String()
} }
// SQL returns the SQL starting at the token that triggered a syntax error.
func (e *Error) SQL() string {
return e.sql
}
type errorString string type errorString string
func (e errorString) Error() string { return string(e) } func (e errorString) Error() string { return string(e) }

26
error_test.go Normal file
View File

@@ -0,0 +1,26 @@
package sqlite3
import (
"strings"
"testing"
)
func TestError(t *testing.T) {
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
}
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") {
t.Errorf("got %q", s)
}
}

83
mem_test.go Normal file
View File

@@ -0,0 +1,83 @@
package sqlite3
import (
"math"
"testing"
)
func Test_memory_view_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(0, 8)
t.Error("want panic")
}
func Test_memory_view_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(126, 8)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(0)
t.Error("want panic")
}
func Test_memory_readUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(126)
t.Error("want panic")
}
func Test_memory_readUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(0)
t.Error("want panic")
}
func Test_memory_readUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(126)
t.Error("want panic")
}
func Test_memory_writeUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(126, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(126, 1)
t.Error("want panic")
}
func Test_memory_readString_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readString(130, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -0,0 +1,15 @@
package compile_empty
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestCompile_empty(t *testing.T) {
sqlite3.Binary = []byte("\x00asm\x01\x00\x00\x00")
_, err := sqlite3.Open(":memory:")
if err == nil {
t.Error("want error")
}
}

View File

@@ -0,0 +1,15 @@
package compile_empty
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestCompile_empty(t *testing.T) {
sqlite3.Path = "sqlite3.wasm"
_, err := sqlite3.Open(":memory:")
if err == nil {
t.Error("want error")
}
}

View File

@@ -15,12 +15,12 @@ func TestDir(t *testing.T) {
} }
var serr *sqlite3.Error var serr *sqlite3.Error
if !errors.As(err, &serr) { if !errors.As(err, &serr) {
t.Fatal("want sqlite3.Error") t.Fatalf("got %T, want sqlite3.Error", err)
} }
if serr.Code() != sqlite3.CANTOPEN { if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Error("want sqlite3.CANTOPEN") t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
} }
if got := err.Error(); got != "sqlite3: unable to open database file" { if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got) t.Error("got message: ", got)
} }
} }

View File

@@ -17,7 +17,7 @@ func Test_vfsExit(t *testing.T) {
mem := newMemory(128) mem := newMemory(128)
defer func() { _ = recover() }() defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1) vfsExit(context.TODO(), mem.mod, 1)
t.Error("should have panicked") t.Error("want panic")
} }
func Test_vfsLocaltime(t *testing.T) { func Test_vfsLocaltime(t *testing.T) {