diff --git a/conn.go b/conn.go index 18fcb61..b90ed51 100644 --- a/conn.go +++ b/conn.go @@ -68,6 +68,8 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { // open blob handles, and/or unfinished backup objects, // Close will leave the database connection open and return [BUSY]. // +// It is safe to close a nil, zero or closed connection. +// // https://www.sqlite.org/c3ref/close.html func (c *Conn) Close() error { if c == nil || c.handle == 0 { @@ -179,6 +181,10 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { // // https://www.sqlite.org/c3ref/prepare.html func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) { + if emptyStatement(sql) { + return nil, "", nil + } + defer c.arena.reset() stmtPtr := c.arena.new(ptrlen) tailPtr := c.arena.new(ptrlen) diff --git a/driver/driver.go b/driver/driver.go index 4a17839..8f04b8a 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -99,12 +99,14 @@ func (c conn) Prepare(query string) (driver.Stmt, error) { } if tail != "" { // Check if the tail contains any SQL. - s, _, err := c.conn.Prepare(tail) + st, _, err := c.conn.Prepare(tail) if err != nil { + s.Close() return nil, err } - if s != nil { + if st != nil { s.Close() + st.Close() return nil, tailErr } } diff --git a/driver/time.go b/driver/time.go index c8f592e..0a9a5b9 100644 --- a/driver/time.go +++ b/driver/time.go @@ -9,8 +9,22 @@ import ( // if it roundtrips back to the same string. // This way times can be persisted to, and recovered from, the database, // but if a string is needed, [database.sql] will recover the same string. -// TODO: optimize and fuzz test. func maybeDate(text string) driver.Value { + // Weed out (some) values that can't possibly be + // [time.RFC3339Nano] timestamps. + if len(text) < len("2006-01-02T15:04:05") { + return text + } + if text[4] != '-' || text[10] != 'T' || text[16] != ':' { + return text + } + for _, c := range []byte(text[:4]) { + if c < '0' || '9' < c { + return text + } + } + + // Slow path. date, err := time.Parse(time.RFC3339Nano, text) if err == nil && date.Format(time.RFC3339Nano) == text { return date diff --git a/driver/time_test.go b/driver/time_test.go new file mode 100644 index 0000000..7fd3711 --- /dev/null +++ b/driver/time_test.go @@ -0,0 +1,42 @@ +package driver + +import ( + "testing" + "time" +) + +func Fuzz_maybeDate(f *testing.F) { + f.Add("") + f.Add(" ") + f.Add("SQLite") + f.Add(time.RFC3339) + f.Add(time.RFC3339Nano) + f.Add(time.Layout) + f.Add(time.DateTime) + f.Add(time.DateOnly) + f.Add(time.TimeOnly) + + f.Fuzz(func(t *testing.T, str string) { + value := maybeDate(str) + + switch v := value.(type) { + case time.Time: + // Make sure times round-trip to the same string: + // https://pkg.go.dev/database/sql#Rows.Scan + if v.Format(time.RFC3339Nano) != str { + t.Fatalf("did not round-trip: %q", str) + } + case string: + if v != str { + t.Fatalf("did not round-trip: %q", str) + } + + date, err := time.Parse(time.RFC3339Nano, str) + if err == nil && date.Format(time.RFC3339Nano) == str { + t.Fatalf("would round-trip: %q", str) + } + default: + t.Fatalf("invalid type %T: %q", v, str) + } + }) +} diff --git a/mock_test.go b/mock_test.go new file mode 100644 index 0000000..fcb5ccc --- /dev/null +++ b/mock_test.go @@ -0,0 +1,161 @@ +package sqlite3 + +import ( + "context" + "encoding/binary" + "math" + + "github.com/tetratelabs/wazero/api" +) + +func init() { + Path = "./embed/sqlite3.wasm" +} + +func newMemory(size uint32) memory { + mem := make(mockMemory, size) + return memory{mockModule{&mem}} +} + +type mockModule struct { + memory api.Memory +} + +func (m mockModule) Memory() api.Memory { return m.memory } +func (m mockModule) String() string { return "mockModule" } +func (m mockModule) Name() string { return "mockModule" } + +func (m mockModule) ExportedGlobal(name string) api.Global { return nil } +func (m mockModule) ExportedMemory(name string) api.Memory { return nil } +func (m mockModule) ExportedFunction(name string) api.Function { return nil } +func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil } +func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil } +func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil } +func (m mockModule) Close(context.Context) error { return nil } + +type mockMemory []byte + +func (m mockMemory) Definition() api.MemoryDefinition { return nil } + +func (m mockMemory) Size() uint32 { return uint32(len(m)) } + +func (m mockMemory) ReadByte(offset uint32) (byte, bool) { + if offset >= m.Size() { + return 0, false + } + return m[offset], true +} + +func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) { + if !m.hasSize(offset, 2) { + return 0, false + } + return binary.LittleEndian.Uint16(m[offset : offset+2]), true +} + +func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) { + if !m.hasSize(offset, 4) { + return 0, false + } + return binary.LittleEndian.Uint32(m[offset : offset+4]), true +} + +func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) { + v, ok := m.ReadUint32Le(offset) + if !ok { + return 0, false + } + return math.Float32frombits(v), true +} + +func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) { + if !m.hasSize(offset, 8) { + return 0, false + } + return binary.LittleEndian.Uint64(m[offset : offset+8]), true +} + +func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) { + v, ok := m.ReadUint64Le(offset) + if !ok { + return 0, false + } + return math.Float64frombits(v), true +} + +func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) { + if !m.hasSize(offset, byteCount) { + return nil, false + } + return m[offset : offset+byteCount : offset+byteCount], true +} + +func (m mockMemory) WriteByte(offset uint32, v byte) bool { + if offset >= m.Size() { + return false + } + m[offset] = v + return true +} + +func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool { + if !m.hasSize(offset, 2) { + return false + } + binary.LittleEndian.PutUint16(m[offset:], v) + return true +} + +func (m mockMemory) WriteUint32Le(offset, v uint32) bool { + if !m.hasSize(offset, 4) { + return false + } + binary.LittleEndian.PutUint32(m[offset:], v) + return true +} + +func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool { + return m.WriteUint32Le(offset, math.Float32bits(v)) +} + +func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool { + if !m.hasSize(offset, 8) { + return false + } + binary.LittleEndian.PutUint64(m[offset:], v) + return true +} + +func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool { + return m.WriteUint64Le(offset, math.Float64bits(v)) +} + +func (m mockMemory) Write(offset uint32, val []byte) bool { + if !m.hasSize(offset, uint32(len(val))) { + return false + } + copy(m[offset:], val) + return true +} + +func (m mockMemory) WriteString(offset uint32, val string) bool { + if !m.hasSize(offset, uint32(len(val))) { + return false + } + copy(m[offset:], val) + return true +} + +func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) { + prev := (len(*m) + 65535) / 65536 + *m = append(*m, make([]byte, 65536*delta)...) + return uint32(prev), true +} + +func (m mockMemory) PageSize() (result uint32) { + return uint32(len(m) / 65536) +} + +func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool { + return uint64(offset)+uint64(byteCount) <= uint64(len(m)) +} diff --git a/stmt.go b/stmt.go index 7f1d93e..419c295 100644 --- a/stmt.go +++ b/stmt.go @@ -15,6 +15,8 @@ type Stmt struct { // Close destroys the prepared statement object. // +// It is safe to close a nil, zero or closed prepared statement. +// // https://www.sqlite.org/c3ref/finalize.html func (s *Stmt) Close() error { if s == nil || s.handle == 0 { diff --git a/util.go b/util.go new file mode 100644 index 0000000..7a5bf13 --- /dev/null +++ b/util.go @@ -0,0 +1,16 @@ +package sqlite3 + +// Return true if stmt is an empty SQL statement. +// This is used as an optimization. +// It's OK to always return false here. +func emptyStatement(stmt string) bool { + for _, b := range []byte(stmt) { + switch b { + case ' ', '\n', '\r', '\t', '\v', '\f': + case ';': + default: + return false + } + } + return true +} diff --git a/util_test.go b/util_test.go index fcb5ccc..48e6b95 100644 --- a/util_test.go +++ b/util_test.go @@ -1,161 +1,54 @@ package sqlite3 import ( - "context" - "encoding/binary" - "math" - - "github.com/tetratelabs/wazero/api" + "testing" ) -func init() { - Path = "./embed/sqlite3.wasm" -} - -func newMemory(size uint32) memory { - mem := make(mockMemory, size) - return memory{mockModule{&mem}} -} - -type mockModule struct { - memory api.Memory -} - -func (m mockModule) Memory() api.Memory { return m.memory } -func (m mockModule) String() string { return "mockModule" } -func (m mockModule) Name() string { return "mockModule" } - -func (m mockModule) ExportedGlobal(name string) api.Global { return nil } -func (m mockModule) ExportedMemory(name string) api.Memory { return nil } -func (m mockModule) ExportedFunction(name string) api.Function { return nil } -func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil } -func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil } -func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil } -func (m mockModule) Close(context.Context) error { return nil } - -type mockMemory []byte - -func (m mockMemory) Definition() api.MemoryDefinition { return nil } - -func (m mockMemory) Size() uint32 { return uint32(len(m)) } - -func (m mockMemory) ReadByte(offset uint32) (byte, bool) { - if offset >= m.Size() { - return 0, false +func Test_emptyStatement(t *testing.T) { + tests := []struct { + name string + stmt string + want bool + }{ + {"empty", "", true}, + {"space", " ", true}, + {"separator", ";\n ", true}, + {"begin", "BEGIN", false}, + {"select", "SELECT 1;", false}, } - return m[offset], true -} - -func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) { - if !m.hasSize(offset, 2) { - return 0, false + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := emptyStatement(tt.stmt); got != tt.want { + t.Errorf("emptyStatement(%q) = %v, want %v", tt.stmt, got, tt.want) + } + }) } - return binary.LittleEndian.Uint16(m[offset : offset+2]), true } -func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) { - if !m.hasSize(offset, 4) { - return 0, false +func Fuzz_emptyStatement(f *testing.F) { + f.Add("") + f.Add(" ") + f.Add(";\n ") + f.Add("BEGIN") + f.Add("SELECT 1;") + + db, err := Open(":memory:") + if err != nil { + f.Fatal(err) } - return binary.LittleEndian.Uint32(m[offset : offset+4]), true -} + defer db.Close() -func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) { - v, ok := m.ReadUint32Le(offset) - if !ok { - return 0, false - } - return math.Float32frombits(v), true -} - -func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) { - if !m.hasSize(offset, 8) { - return 0, false - } - return binary.LittleEndian.Uint64(m[offset : offset+8]), true -} - -func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) { - v, ok := m.ReadUint64Le(offset) - if !ok { - return 0, false - } - return math.Float64frombits(v), true -} - -func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) { - if !m.hasSize(offset, byteCount) { - return nil, false - } - return m[offset : offset+byteCount : offset+byteCount], true -} - -func (m mockMemory) WriteByte(offset uint32, v byte) bool { - if offset >= m.Size() { - return false - } - m[offset] = v - return true -} - -func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool { - if !m.hasSize(offset, 2) { - return false - } - binary.LittleEndian.PutUint16(m[offset:], v) - return true -} - -func (m mockMemory) WriteUint32Le(offset, v uint32) bool { - if !m.hasSize(offset, 4) { - return false - } - binary.LittleEndian.PutUint32(m[offset:], v) - return true -} - -func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool { - return m.WriteUint32Le(offset, math.Float32bits(v)) -} - -func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool { - if !m.hasSize(offset, 8) { - return false - } - binary.LittleEndian.PutUint64(m[offset:], v) - return true -} - -func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool { - return m.WriteUint64Le(offset, math.Float64bits(v)) -} - -func (m mockMemory) Write(offset uint32, val []byte) bool { - if !m.hasSize(offset, uint32(len(val))) { - return false - } - copy(m[offset:], val) - return true -} - -func (m mockMemory) WriteString(offset uint32, val string) bool { - if !m.hasSize(offset, uint32(len(val))) { - return false - } - copy(m[offset:], val) - return true -} - -func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) { - prev := (len(*m) + 65535) / 65536 - *m = append(*m, make([]byte, 65536*delta)...) - return uint32(prev), true -} - -func (m mockMemory) PageSize() (result uint32) { - return uint32(len(m) / 65536) -} - -func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool { - return uint64(offset)+uint64(byteCount) <= uint64(len(m)) + f.Fuzz(func(t *testing.T, sql string) { + // If empty, SQLite parses it as empty. + if emptyStatement(sql) { + stmt, _, err := db.Prepare(sql) + if err != nil { + t.Error(err) + } + if stmt != nil { + t.Error(stmt) + } + stmt.Close() + } + }) }