MustPrepare.

This commit is contained in:
Nuno Cruces
2023-02-25 01:29:46 +00:00
parent c1472a48b0
commit e6cd0aaf87
3 changed files with 62 additions and 2 deletions

15
conn.go
View File

@@ -103,6 +103,19 @@ func (c *Conn) Exec(sql string) error {
return c.error(r[0])
}
// MustPrepare calls [Conn.Prepare] and panics on error,
// or a non empty tail.
func (c *Conn) MustPrepare(sql string) *Stmt {
s, tail, err := c.PrepareFlags(sql, 0)
if err != nil {
panic(err)
}
if !emptyStatement(tail) {
panic(tailErr)
}
return s
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
@@ -205,7 +218,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
} else {
c.pending.Reset()

View File

@@ -65,14 +65,15 @@ type errorString string
func (e errorString) Error() string { return string(e) }
const (
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
tailErr = errorString("sqlite3: non-empty tail")
notImplErr = errorString("sqlite3: not implemented")
)

View File

@@ -225,3 +225,49 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Error("got message: ", got)
}
}
func TestConn_MustPrepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
stmt := db.MustPrepare(``)
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_MustPrepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT 1; -- HERE`)
t.Error("want panic")
}
func TestConn_MustPrepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}