diff --git a/conn.go b/conn.go index 8e12e0c..6bc1037 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,8 @@ import ( "math" "runtime" "sync" + + "github.com/tetratelabs/wazero/api" ) // Conn is a database connection handle. @@ -54,10 +56,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { connPtr := c.arena.new(ptrlen) namePtr := c.arena.string(filename) - r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0) - if err != nil { - panic(err) - } + r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0) c.handle = c.mem.readUint32(connPtr) if err := c.error(r[0]); err != nil { @@ -82,11 +81,7 @@ func (c *Conn) Close() error { c.SetInterrupt(context.Background()) - r, err := c.api.close.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } - + r := c.call(c.api.close, uint64(c.handle)) if err := c.error(r[0]); err != nil { return err } @@ -104,10 +99,7 @@ func (c *Conn) Exec(sql string) error { defer c.arena.reset() sqlPtr := c.arena.string(sql) - r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) - if err != nil { - panic(err) - } + r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) return c.error(r[0]) } @@ -132,12 +124,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str tailPtr := c.arena.new(ptrlen) sqlPtr := c.arena.string(sql) - r, err := c.api.prepare.Call(c.ctx, uint64(c.handle), + r := c.call(c.api.prepare, uint64(c.handle), uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), uint64(stmtPtr), uint64(tailPtr)) - if err != nil { - panic(err) - } stmt = &Stmt{c: c} stmt.handle = c.mem.readUint32(stmtPtr) @@ -157,10 +146,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str // // https://www.sqlite.org/c3ref/get_autocommit.html func (c *Conn) GetAutocommit() bool { - r, err := c.api.autocommit.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } + r := c.call(c.api.autocommit, uint64(c.handle)) return r[0] != 0 } @@ -169,10 +155,7 @@ func (c *Conn) GetAutocommit() bool { // // https://www.sqlite.org/c3ref/last_insert_rowid.html func (c *Conn) LastInsertRowID() uint64 { - r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } + r := c.call(c.api.lastRowid, uint64(c.handle)) return r[0] } @@ -182,10 +165,7 @@ func (c *Conn) LastInsertRowID() uint64 { // // https://www.sqlite.org/c3ref/changes.html func (c *Conn) Changes() uint64 { - r, err := c.api.changes.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } + r := c.call(c.api.changes, uint64(c.handle)) return r[0] } @@ -270,10 +250,7 @@ func (c *Conn) sendInterrupt() { defer c.mtx.Unlock() // This is safe to call from a goroutine // because it doesn't touch the C stack. - _, err := c.api.interrupt.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } + c.call(c.api.interrupt, uint64(c.handle)) } // Savepoint creates a named SQLite transaction using SAVEPOINT. @@ -383,21 +360,23 @@ func (c *Conn) error(rc uint64, sql ...string) error { return &err } +func (c *Conn) call(fn api.Function, params ...uint64) []uint64 { + r, err := fn.Call(c.ctx, params...) + if err != nil { + panic(err) + } + return r +} + func (c *Conn) free(ptr uint32) { if ptr == 0 { return } - _, err := c.api.free.Call(c.ctx, uint64(ptr)) - if err != nil { - panic(err) - } + c.call(c.api.free, uint64(ptr)) } func (c *Conn) new(size uint32) uint32 { - r, err := c.api.malloc.Call(c.ctx, uint64(size)) - if err != nil { - panic(err) - } + r := c.call(c.api.malloc, uint64(size)) ptr := uint32(r[0]) if ptr == 0 && size != 0 { panic(oomErr) diff --git a/conn_test.go b/conn_test.go index d00bfca..d12bbcd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -6,6 +6,20 @@ import ( "testing" ) +func TestConn_call_nil(t *testing.T) { + t.Parallel() + + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + defer func() { _ = recover() }() + db.call(db.api.free) + t.Error("want panic") +} + func TestConn_new(t *testing.T) { t.Parallel() diff --git a/stmt.go b/stmt.go index 1215482..f739078 100644 --- a/stmt.go +++ b/stmt.go @@ -24,10 +24,7 @@ func (s *Stmt) Close() error { return nil } - r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle)) - if err != nil { - panic(err) - } + r := s.c.call(s.c.api.finalize, uint64(s.handle)) s.handle = 0 return s.c.error(r[0]) @@ -37,10 +34,7 @@ func (s *Stmt) Close() error { // // https://www.sqlite.org/c3ref/reset.html func (s *Stmt) Reset() error { - r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle)) - if err != nil { - panic(err) - } + r := s.c.call(s.c.api.reset, uint64(s.handle)) s.err = nil return s.c.error(r[0]) } @@ -49,10 +43,7 @@ func (s *Stmt) Reset() error { // // https://www.sqlite.org/c3ref/clear_bindings.html func (s *Stmt) ClearBindings() error { - r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle)) - if err != nil { - panic(err) - } + r := s.c.call(s.c.api.clearBindings, uint64(s.handle)) return s.c.error(r[0]) } @@ -67,10 +58,7 @@ func (s *Stmt) ClearBindings() error { // https://www.sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { s.c.checkInterrupt() - r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle)) - if err != nil { - panic(err) - } + r := s.c.call(s.c.api.step, uint64(s.handle)) if r[0] == _ROW { return true } @@ -102,11 +90,8 @@ func (s *Stmt) Exec() error { // // https://www.sqlite.org/c3ref/bind_parameter_count.html func (s *Stmt) BindCount() int { - r, err := s.c.api.bindCount.Call(s.c.ctx, + r := s.c.call(s.c.api.bindCount, uint64(s.handle)) - if err != nil { - panic(err) - } return int(r[0]) } @@ -117,11 +102,8 @@ func (s *Stmt) BindCount() int { func (s *Stmt) BindIndex(name string) int { defer s.c.arena.reset() namePtr := s.c.arena.string(name) - r, err := s.c.api.bindIndex.Call(s.c.ctx, + r := s.c.call(s.c.api.bindIndex, uint64(s.handle), uint64(namePtr)) - if err != nil { - panic(err) - } return int(r[0]) } @@ -130,11 +112,8 @@ func (s *Stmt) BindIndex(name string) int { // // https://www.sqlite.org/c3ref/bind_parameter_name.html func (s *Stmt) BindName(param int) string { - r, err := s.c.api.bindName.Call(s.c.ctx, + r := s.c.call(s.c.api.bindName, uint64(s.handle), uint64(param)) - if err != nil { - panic(err) - } ptr := uint32(r[0]) if ptr == 0 { @@ -169,11 +148,8 @@ func (s *Stmt) BindInt(param int, value int) error { // // https://www.sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindInt64(param int, value int64) error { - r, err := s.c.api.bindInteger.Call(s.c.ctx, + r := s.c.call(s.c.api.bindInteger, uint64(s.handle), uint64(param), uint64(value)) - if err != nil { - panic(err) - } return s.c.error(r[0]) } @@ -182,11 +158,8 @@ func (s *Stmt) BindInt64(param int, value int64) error { // // https://www.sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindFloat(param int, value float64) error { - r, err := s.c.api.bindFloat.Call(s.c.ctx, + r := s.c.call(s.c.api.bindFloat, uint64(s.handle), uint64(param), math.Float64bits(value)) - if err != nil { - panic(err) - } return s.c.error(r[0]) } @@ -196,13 +169,10 @@ func (s *Stmt) BindFloat(param int, value float64) error { // https://www.sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindText(param int, value string) error { ptr := s.c.newString(value) - r, err := s.c.api.bindText.Call(s.c.ctx, + r := s.c.call(s.c.api.bindText, uint64(s.handle), uint64(param), uint64(ptr), uint64(len(value)), s.c.api.destructor, _UTF8) - if err != nil { - panic(err) - } return s.c.error(r[0]) } @@ -213,13 +183,10 @@ func (s *Stmt) BindText(param int, value string) error { // https://www.sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindBlob(param int, value []byte) error { ptr := s.c.newBytes(value) - r, err := s.c.api.bindBlob.Call(s.c.ctx, + r := s.c.call(s.c.api.bindBlob, uint64(s.handle), uint64(param), uint64(ptr), uint64(len(value)), s.c.api.destructor) - if err != nil { - panic(err) - } return s.c.error(r[0]) } @@ -228,11 +195,8 @@ func (s *Stmt) BindBlob(param int, value []byte) error { // // https://www.sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindZeroBlob(param int, n int64) error { - r, err := s.c.api.bindZeroBlob.Call(s.c.ctx, + r := s.c.call(s.c.api.bindZeroBlob, uint64(s.handle), uint64(param), uint64(n)) - if err != nil { - panic(err) - } return s.c.error(r[0]) } @@ -241,11 +205,8 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error { // // https://www.sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindNull(param int) error { - r, err := s.c.api.bindNull.Call(s.c.ctx, + r := s.c.call(s.c.api.bindNull, uint64(s.handle), uint64(param)) - if err != nil { - panic(err) - } return s.c.error(r[0]) } @@ -271,11 +232,8 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error { // // https://www.sqlite.org/c3ref/column_count.html func (s *Stmt) ColumnCount() int { - r, err := s.c.api.columnCount.Call(s.c.ctx, + r := s.c.call(s.c.api.columnCount, uint64(s.handle)) - if err != nil { - panic(err) - } return int(r[0]) } @@ -284,11 +242,8 @@ func (s *Stmt) ColumnCount() int { // // https://www.sqlite.org/c3ref/column_name.html func (s *Stmt) ColumnName(col int) string { - r, err := s.c.api.columnName.Call(s.c.ctx, + r := s.c.call(s.c.api.columnName, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } ptr := uint32(r[0]) if ptr == 0 { @@ -302,11 +257,8 @@ func (s *Stmt) ColumnName(col int) string { // // https://www.sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnType(col int) Datatype { - r, err := s.c.api.columnType.Call(s.c.ctx, + r := s.c.call(s.c.api.columnType, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } return Datatype(r[0]) } @@ -337,11 +289,8 @@ func (s *Stmt) ColumnInt(col int) int { // // https://www.sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnInt64(col int) int64 { - r, err := s.c.api.columnInteger.Call(s.c.ctx, + r := s.c.call(s.c.api.columnInteger, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } return int64(r[0]) } @@ -350,11 +299,8 @@ func (s *Stmt) ColumnInt64(col int) int64 { // // https://www.sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnFloat(col int) float64 { - r, err := s.c.api.columnFloat.Call(s.c.ctx, + r := s.c.call(s.c.api.columnFloat, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } return math.Float64frombits(r[0]) } @@ -388,27 +334,18 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time { // // https://www.sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnText(col int) string { - r, err := s.c.api.columnText.Call(s.c.ctx, + r := s.c.call(s.c.api.columnText, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } ptr := uint32(r[0]) if ptr == 0 { - r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle)) - if err != nil { - panic(err) - } + r = s.c.call(s.c.api.errcode, uint64(s.handle)) s.err = s.c.error(r[0]) return "" } - r, err = s.c.api.columnBytes.Call(s.c.ctx, + r = s.c.call(s.c.api.columnBytes, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } mem := s.c.mem.view(ptr, uint32(r[0])) return string(mem) @@ -420,27 +357,18 @@ func (s *Stmt) ColumnText(col int) string { // // https://www.sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { - r, err := s.c.api.columnBlob.Call(s.c.ctx, + r := s.c.call(s.c.api.columnBlob, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } ptr := uint32(r[0]) if ptr == 0 { - r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle)) - if err != nil { - panic(err) - } + r = s.c.call(s.c.api.errcode, uint64(s.handle)) s.err = s.c.error(r[0]) return buf[0:0] } - r, err = s.c.api.columnBytes.Call(s.c.ctx, + r = s.c.call(s.c.api.columnBytes, uint64(s.handle), uint64(col)) - if err != nil { - panic(err) - } mem := s.c.mem.view(ptr, uint32(r[0])) return append(buf[0:0], mem...)