From e6cd0aaf8781941c28e7e9471f99fa6840d1fba4 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 25 Feb 2023 01:29:46 +0000 Subject: [PATCH] MustPrepare. --- conn.go | 15 ++++++++++++++- error.go | 3 ++- tests/conn_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 6bc1037..640a6fa 100644 --- a/conn.go +++ b/conn.go @@ -103,6 +103,19 @@ func (c *Conn) Exec(sql string) error { return c.error(r[0]) } +// MustPrepare calls [Conn.Prepare] and panics on error, +// or a non empty tail. +func (c *Conn) MustPrepare(sql string) *Stmt { + s, tail, err := c.PrepareFlags(sql, 0) + if err != nil { + panic(err) + } + if !emptyStatement(tail) { + panic(tailErr) + } + return s +} + // Prepare calls [Conn.PrepareFlags] with no flags. func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { return c.PrepareFlags(sql, 0) @@ -205,7 +218,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { // Creating an uncompleted SQL statement prevents SQLite from ignoring // an interrupt that comes before any other statements are started. if c.pending == nil { - c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`) + c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`) c.pending.Step() } else { c.pending.Reset() diff --git a/error.go b/error.go index 8e45f0b..74bc287 100644 --- a/error.go +++ b/error.go @@ -65,14 +65,15 @@ type errorString string func (e errorString) Error() string { return string(e) } const ( - binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded") nilErr = errorString("sqlite3: invalid memory address or null pointer dereference") oomErr = errorString("sqlite3: out of memory") rangeErr = errorString("sqlite3: index out of range") noNulErr = errorString("sqlite3: missing NUL terminator") noGlobalErr = errorString("sqlite3: could not find global: ") noFuncErr = errorString("sqlite3: could not find function: ") + binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded") timeErr = errorString("sqlite3: invalid time value") + tailErr = errorString("sqlite3: non-empty tail") notImplErr = errorString("sqlite3: not implemented") ) diff --git a/tests/conn_test.go b/tests/conn_test.go index bcab3cc..536c92c 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -225,3 +225,49 @@ func TestConn_Prepare_invalid(t *testing.T) { t.Error("got message: ", got) } } + +func TestConn_MustPrepare_empty(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + defer func() { _ = recover() }() + stmt := db.MustPrepare(``) + defer stmt.Close() + + if stmt != nil { + t.Error("want nil") + } +} + +func TestConn_MustPrepare_tail(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + defer func() { _ = recover() }() + _ = db.MustPrepare(`SELECT 1; -- HERE`) + t.Error("want panic") +} + +func TestConn_MustPrepare_invalid(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + defer func() { _ = recover() }() + _ = db.MustPrepare(`SELECT`) + t.Error("want panic") +}