From a33b8d3459d5232898724349f3fe70975bddcdbb Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 10 Feb 2023 16:42:49 +0000 Subject: [PATCH] Tests. --- api.go | 31 +++++--- conn.go | 22 ++++-- conn_test.go | 107 +++++++++++++++++++++++--- error.go | 6 ++ error_test.go | 26 +++++++ mem_test.go | 83 ++++++++++++++++++++ tests/compile/empty/compile_test.go | 15 ++++ tests/compile/missing/compile_test.go | 15 ++++ tests/dir_test.go | 8 +- vfs_test.go | 2 +- 10 files changed, 287 insertions(+), 28 deletions(-) create mode 100644 error_test.go create mode 100644 mem_test.go create mode 100644 tests/compile/empty/compile_test.go create mode 100644 tests/compile/missing/compile_test.go diff --git a/api.go b/api.go index e8997a9..5a0b075 100644 --- a/api.go +++ b/api.go @@ -1,28 +1,37 @@ package sqlite3 -import "github.com/tetratelabs/wazero/api" +import ( + "context" -func newConn(module api.Module) *Conn { + "github.com/tetratelabs/wazero/api" +) + +func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) { getFun := func(name string) api.Function { f := module.ExportedFunction(name) if f == nil { - panic(noFuncErr + errorString(name)) + err = noFuncErr + errorString(name) + return nil } return f } - global := module.ExportedGlobal("malloc_destructor") - if global == nil { - panic(noGlobalErr + "malloc_destructor") + getPtr := func(name string) uint32 { + global := module.ExportedGlobal(name) + if global == nil { + err = noGlobalErr + errorString(name) + return 0 + } + return memory{module}.readUint32(uint32(global.Get())) } - destructor := memory{module}.readUint32(uint32(global.Get())) - return &Conn{ + c := Conn{ + ctx: ctx, mem: memory{module}, api: sqliteAPI{ malloc: getFun("malloc"), free: getFun("free"), - destructor: uint64(destructor), + destructor: uint64(getPtr("malloc_destructor")), errcode: getFun("sqlite3_errcode"), errstr: getFun("sqlite3_errstr"), errmsg: getFun("sqlite3_errmsg"), @@ -49,6 +58,10 @@ func newConn(module api.Module) *Conn { columnType: getFun("sqlite3_column_type"), }, } + if err != nil { + return nil, err + } + return &c, nil } type sqliteAPI struct { diff --git a/conn.go b/conn.go index d74b314..d777e16 100644 --- a/conn.go +++ b/conn.go @@ -2,6 +2,7 @@ package sqlite3 import ( "context" + "math" ) // Conn is a database connection handle. @@ -34,8 +35,11 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { } }() - c := newConn(module) - c.ctx = context.Background() + c, err := newConn(ctx, module) + if err != nil { + return nil, err + } + namePtr := c.newString(filename) connPtr := c.new(ptrlen) defer c.free(namePtr) @@ -124,7 +128,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str i := c.mem.readUint32(tailPtr) tail = sql[i-sqlPtr:] - if err := c.error(r[0]); err != nil { + if err := c.error(r[0], sql); err != nil { return nil, "", err } if stmt.handle == 0 { @@ -133,7 +137,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str return } -func (c *Conn) error(rc uint64) error { +func (c *Conn) error(rc uint64, sql ...string) error { if rc == _OK { return nil } @@ -146,12 +150,20 @@ func (c *Conn) error(rc uint64) error { var r []uint64 - // Do this first, sqlite3_errmsg is guaranteed to never change the value of the error code. + // sqlite3_errmsg is guaranteed to never change the value of the error code. r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) if r != nil { err.msg = c.mem.readString(uint32(r[0]), 512) } + if sql != nil { + // sqlite3_error_offset is guaranteed to never change the value of the error code. + r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle)) + if r != nil && r[0] != math.MaxUint32 { + err.sql = sql[0][r[0]:] + } + } + r, _ = c.api.errstr.Call(c.ctx, rc) if r != nil { err.str = c.mem.readString(uint32(r[0]), 512) diff --git a/conn_test.go b/conn_test.go index 07c2689..83ed24b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,10 +2,104 @@ package sqlite3 import ( "bytes" + "errors" "math" "testing" ) +func TestConn_Close(t *testing.T) { + var conn *Conn + conn.Close() +} + +func TestConn_Close_BUSY(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare("BEGIN") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + err = db.Close() + if err == nil { + t.Fatal("want error") + } + var serr *Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != BUSY { + t.Errorf("got %d, want sqlite3.BUSY", rc) + } + if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` { + t.Error("got message: ", got) + } +} + +func TestConn_Prepare_Empty(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare("") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt != nil { + t.Error("want nil") + } +} + +func TestConn_Prepare_Invalid(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var serr *Error + + _, _, err = db.Prepare("SELECT") + if err == nil { + t.Fatal("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` { + t.Error("got message: ", got) + } + + _, _, err = db.Prepare("SELECT * FRM sqlite_schema") + if err == nil { + t.Fatal("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.ERROR", err) + } + if rc := serr.Code(); rc != ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := serr.SQL(); got != `FRM sqlite_schema` { + t.Error("got SQL: ", got) + } + if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` { + t.Error("got message: ", got) + } +} + func TestConn_new(t *testing.T) { db, err := Open(":memory:") if err != nil { @@ -15,7 +109,7 @@ func TestConn_new(t *testing.T) { defer func() { _ = recover() }() db.new(math.MaxUint32) - t.Error("should have panicked") + t.Error("want panic") } func TestConn_newBytes(t *testing.T) { @@ -27,7 +121,7 @@ func TestConn_newBytes(t *testing.T) { ptr := db.newBytes(nil) if ptr != 0 { - t.Errorf("got %x, want nullptr", ptr) + t.Errorf("got %#x, want nullptr", ptr) } buf := []byte("sqlite3") @@ -95,13 +189,13 @@ func TestConn_getString(t *testing.T) { func() { defer func() { _ = recover() }() db.mem.readString(ptr, uint32(len(want)/2)) - t.Error("should have panicked") + t.Error("want panic") }() func() { defer func() { _ = recover() }() db.mem.readString(0, math.MaxUint32) - t.Error("should have panicked") + t.Error("want panic") }() } @@ -121,8 +215,3 @@ func TestConn_free(t *testing.T) { db.free(ptr) } - -func TestConn_Close(t *testing.T) { - var conn *Conn - conn.Close() -} diff --git a/error.go b/error.go index 47585f9..686c71e 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ type Error struct { code uint64 str string msg string + sql string } // Code returns the primary error code for this error. @@ -49,6 +50,11 @@ func (e *Error) Error() string { return b.String() } +// SQL returns the SQL starting at the token that triggered a syntax error. +func (e *Error) SQL() string { + return e.sql +} + type errorString string func (e errorString) Error() string { return string(e) } diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..1e8d769 --- /dev/null +++ b/error_test.go @@ -0,0 +1,26 @@ +package sqlite3 + +import ( + "strings" + "testing" +) + +func TestError(t *testing.T) { + err := Error{code: 0x8080} + if rc := err.Code(); rc != 0x80 { + t.Errorf("got %#x, want 0x80", rc) + } + if rc := err.ExtendedCode(); rc != 0x8080 { + t.Errorf("got %#x, want 0x8080", rc) + } + if s := err.Error(); s != "sqlite3: 32896" { + t.Errorf("got %q", s) + } +} + +func Test_assertErr(t *testing.T) { + err := assertErr() + if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") { + t.Errorf("got %q", s) + } +} diff --git a/mem_test.go b/mem_test.go new file mode 100644 index 0000000..3a9836c --- /dev/null +++ b/mem_test.go @@ -0,0 +1,83 @@ +package sqlite3 + +import ( + "math" + "testing" +) + +func Test_memory_view_nil(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.view(0, 8) + t.Error("want panic") +} + +func Test_memory_view_range(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.view(126, 8) + t.Error("want panic") +} + +func Test_memory_readUint32_nil(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.readUint32(0) + t.Error("want panic") +} + +func Test_memory_readUint32_range(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.readUint32(126) + t.Error("want panic") +} + +func Test_memory_readUint64_nil(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.readUint64(0) + t.Error("want panic") +} + +func Test_memory_readUint64_range(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.readUint64(126) + t.Error("want panic") +} + +func Test_memory_writeUint32_nil(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.writeUint32(0, 1) + t.Error("want panic") +} + +func Test_memory_writeUint32_range(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.writeUint32(126, 1) + t.Error("want panic") +} + +func Test_memory_writeUint64_nil(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.writeUint64(0, 1) + t.Error("want panic") +} + +func Test_memory_writeUint64_range(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.writeUint64(126, 1) + t.Error("want panic") +} + +func Test_memory_readString_range(t *testing.T) { + defer func() { _ = recover() }() + mem := newMemory(128) + mem.readString(130, math.MaxUint32) + t.Error("want panic") +} diff --git a/tests/compile/empty/compile_test.go b/tests/compile/empty/compile_test.go new file mode 100644 index 0000000..9648c6a --- /dev/null +++ b/tests/compile/empty/compile_test.go @@ -0,0 +1,15 @@ +package compile_empty + +import ( + "testing" + + "github.com/ncruces/go-sqlite3" +) + +func TestCompile_empty(t *testing.T) { + sqlite3.Binary = []byte("\x00asm\x01\x00\x00\x00") + _, err := sqlite3.Open(":memory:") + if err == nil { + t.Error("want error") + } +} diff --git a/tests/compile/missing/compile_test.go b/tests/compile/missing/compile_test.go new file mode 100644 index 0000000..be9cdb0 --- /dev/null +++ b/tests/compile/missing/compile_test.go @@ -0,0 +1,15 @@ +package compile_empty + +import ( + "testing" + + "github.com/ncruces/go-sqlite3" +) + +func TestCompile_empty(t *testing.T) { + sqlite3.Path = "sqlite3.wasm" + _, err := sqlite3.Open(":memory:") + if err == nil { + t.Error("want error") + } +} diff --git a/tests/dir_test.go b/tests/dir_test.go index 35e7b07..928961b 100644 --- a/tests/dir_test.go +++ b/tests/dir_test.go @@ -15,12 +15,12 @@ func TestDir(t *testing.T) { } var serr *sqlite3.Error if !errors.As(err, &serr) { - t.Fatal("want sqlite3.Error") + t.Fatalf("got %T, want sqlite3.Error", err) } - if serr.Code() != sqlite3.CANTOPEN { - t.Error("want sqlite3.CANTOPEN") + if rc := serr.Code(); rc != sqlite3.CANTOPEN { + t.Errorf("got %d, want sqlite3.CANTOPEN", rc) } - if got := err.Error(); got != "sqlite3: unable to open database file" { + if got := err.Error(); got != `sqlite3: unable to open database file` { t.Error("got message: ", got) } } diff --git a/vfs_test.go b/vfs_test.go index 61d4b94..42a81ca 100644 --- a/vfs_test.go +++ b/vfs_test.go @@ -17,7 +17,7 @@ func Test_vfsExit(t *testing.T) { mem := newMemory(128) defer func() { _ = recover() }() vfsExit(context.TODO(), mem.mod, 1) - t.Error("should have panicked") + t.Error("want panic") } func Test_vfsLocaltime(t *testing.T) {