From 4160b9a4bb4fdf8017530ec21ed128d39336818f Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 30 Nov 2023 12:26:15 +0000 Subject: [PATCH] Simplify tails. --- conn.go | 8 +++--- driver/driver.go | 13 ++-------- driver/driver_test.go | 24 ++++++------------ stmt.go | 15 ----------- stmt_test.go | 58 ------------------------------------------- 5 files changed, 12 insertions(+), 106 deletions(-) delete mode 100644 stmt_test.go diff --git a/conn.go b/conn.go index f690d68..2f50fa6 100644 --- a/conn.go +++ b/conn.go @@ -172,9 +172,6 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str if len(sql) > _MAX_LENGTH { return nil, "", TOOBIG } - if emptyStatement(sql) { - return nil, "", nil - } defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) @@ -187,8 +184,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str stmt = &Stmt{c: c} stmt.handle = util.ReadUint32(c.mod, stmtPtr) - i := util.ReadUint32(c.mod, tailPtr) - tail = sql[i-sqlPtr:] + if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" { + tail = sql + } if err := c.error(r, sql); err != nil { return nil, "", err diff --git a/driver/driver.go b/driver/driver.go index daa6003..f867c67 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -247,17 +247,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e return nil, err } if tail != "" { - // Check if the tail contains any SQL. - st, _, err := c.Conn.Prepare(tail) - if err != nil { - s.Close() - return nil, err - } - if st != nil { - s.Close() - st.Close() - return nil, util.TailErr - } + s.Close() + return nil, util.TailErr } return &stmt{s, c.Conn}, nil } diff --git a/driver/driver_test.go b/driver/driver_test.go index 04541e4..cf32f74 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -186,12 +186,6 @@ func Test_Prepare(t *testing.T) { } defer db.Close() - stmt, err := db.Prepare(`SELECT 1; -- HERE`) - if err != nil { - t.Error(err) - } - defer stmt.Close() - var serr *sqlite3.Error _, err = db.Prepare(`SELECT`) if err == nil { @@ -207,18 +201,14 @@ func Test_Prepare(t *testing.T) { t.Error("got message:", got) } + _, err = db.Prepare(`SELECT 1; `) + if err.Error() != string(util.TailErr) { + t.Error("want tailErr") + } + _, err = db.Prepare(`SELECT 1; SELECT`) - if err == nil { - t.Error("want error") - } - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.ERROR { - t.Errorf("got %d, want sqlite3.ERROR", rc) - } - if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` { - t.Error("got message:", got) + if err.Error() != string(util.TailErr) { + t.Error("want tailErr") } _, err = db.Prepare(`SELECT 1; SELECT 2`) diff --git a/stmt.go b/stmt.go index a0b06a3..743e0b8 100644 --- a/stmt.go +++ b/stmt.go @@ -501,18 +501,3 @@ func (s *Stmt) ColumnValue(col int) Value { handle: uint32(r), } } - -// Return true if stmt is an empty SQL statement. -// This is used as an optimization. -// It's OK to always return false here. -func emptyStatement(stmt string) bool { - for _, b := range []byte(stmt) { - switch b { - case ' ', '\n', '\r', '\t', '\v', '\f': - case ';': - default: - return false - } - } - return true -} diff --git a/stmt_test.go b/stmt_test.go deleted file mode 100644 index a670a02..0000000 --- a/stmt_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package sqlite3 - -import "testing" - -func Test_emptyStatement(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - stmt string - want bool - }{ - {"empty", "", true}, - {"space", " ", true}, - {"separator", ";\n ", true}, - {"begin", "BEGIN", false}, - {"select", "SELECT 1;", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := emptyStatement(tt.stmt); got != tt.want { - t.Errorf("got %v, want %v", got, tt.want) - } - }) - } -} - -func Fuzz_emptyStatement(f *testing.F) { - f.Add("") - f.Add(" ") - f.Add(";\n ") - f.Add("; ;\v") - f.Add("BEGIN") - f.Add("SELECT 1;") - - db, err := Open(":memory:") - if err != nil { - f.Fatal(err) - } - defer db.Close() - - f.Fuzz(func(t *testing.T, sql string) { - // If empty, SQLite parses it as empty. - if emptyStatement(sql) { - stmt, tail, err := db.Prepare(sql) - if err != nil { - t.Errorf("%q, %v", sql, err) - } - if stmt != nil { - t.Errorf("%q, %v", sql, stmt) - } - if tail != "" { - t.Errorf("%q", sql) - } - stmt.Close() - } - }) -}