Simplify tails.

This commit is contained in:
Nuno Cruces
2023-11-30 12:26:15 +00:00
parent dbaf2d99cd
commit 4160b9a4bb
5 changed files with 12 additions and 106 deletions

View File

@@ -172,9 +172,6 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_LENGTH {
return nil, "", TOOBIG
}
if emptyStatement(sql) {
return nil, "", nil
}
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
@@ -187,8 +184,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
stmt = &Stmt{c: c}
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" {
tail = sql
}
if err := c.error(r, sql); err != nil {
return nil, "", err

View File

@@ -247,17 +247,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
return nil, err
}
if tail != "" {
// Check if the tail contains any SQL.
st, _, err := c.Conn.Prepare(tail)
if err != nil {
s.Close()
return nil, err
}
if st != nil {
s.Close()
st.Close()
return nil, util.TailErr
}
s.Close()
return nil, util.TailErr
}
return &stmt{s, c.Conn}, nil
}

View File

@@ -186,12 +186,6 @@ func Test_Prepare(t *testing.T) {
}
defer db.Close()
stmt, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
var serr *sqlite3.Error
_, err = db.Prepare(`SELECT`)
if err == nil {
@@ -207,18 +201,14 @@ func Test_Prepare(t *testing.T) {
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; `)
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
_, err = db.Prepare(`SELECT 1; SELECT`)
if err == nil {
t.Error("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message:", got)
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)

15
stmt.go
View File

@@ -501,18 +501,3 @@ func (s *Stmt) ColumnValue(col int) Value {
handle: uint32(r),
}
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

View File

@@ -1,58 +0,0 @@
package sqlite3
import "testing"
func Test_emptyStatement(t *testing.T) {
t.Parallel()
tests := []struct {
name string
stmt string
want bool
}{
{"empty", "", true},
{"space", " ", true},
{"separator", ";\n ", true},
{"begin", "BEGIN", false},
{"select", "SELECT 1;", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := emptyStatement(tt.stmt); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func Fuzz_emptyStatement(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add(";\n ")
f.Add("; ;\v")
f.Add("BEGIN")
f.Add("SELECT 1;")
db, err := Open(":memory:")
if err != nil {
f.Fatal(err)
}
defer db.Close()
f.Fuzz(func(t *testing.T, sql string) {
// If empty, SQLite parses it as empty.
if emptyStatement(sql) {
stmt, tail, err := db.Prepare(sql)
if err != nil {
t.Errorf("%q, %v", sql, err)
}
if stmt != nil {
t.Errorf("%q, %v", sql, stmt)
}
if tail != "" {
t.Errorf("%q", sql)
}
stmt.Close()
}
})
}