From a51cdb04e6bed927a4d4badac969c70766a34ceb Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 18 Feb 2023 02:57:47 +0000 Subject: [PATCH] Exec fast path. --- driver/driver.go | 41 ++++++++++++++++++++++++++++++++++++----- driver/error.go | 5 ++++- driver/example_test.go | 10 +++------- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index 091f016..fe0d989 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -18,7 +18,7 @@ func init() { type sqlite struct{} func (sqlite) Open(name string) (driver.Conn, error) { - c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI) + c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE) if err != nil { return nil, err } @@ -38,10 +38,10 @@ type conn struct{ conn *sqlite3.Conn } var ( // Ensure these interfaces are implemented: - _ driver.Validator = conn{} - // _ driver.SessionResetter = conn{} - // _ driver.ExecerContext = conn{} + _ driver.Validator = conn{} + _ driver.ExecerContext = conn{} // _ driver.ConnBeginTx = conn{} + // _ driver.SessionResetter = conn{} ) func (c conn) Close() error { @@ -75,13 +75,44 @@ func (c conn) Rollback() error { } func (c conn) Prepare(query string) (driver.Stmt, error) { - s, _, err := c.conn.Prepare(query) + s, tail, err := c.conn.Prepare(query) if err != nil { return nil, err } + if tail != "" { + // Check if the tail contains any SQL. + s, _, err := c.conn.Prepare(tail) + if err != nil { + return nil, err + } + if s != nil { + s.Close() + return nil, tailErr + } + } return stmt{s, c.conn}, nil } +func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if len(args) != 0 { + // Slow path. + return nil, driver.ErrSkip + } + + ch := c.conn.SetInterrupt(ctx.Done()) + defer c.conn.SetInterrupt(ch) + + err := c.conn.Exec(query) + if err != nil { + return nil, err + } + + return result{ + int64(c.conn.LastInsertRowID()), + int64(c.conn.Changes()), + }, nil +} + func pragma(c *sqlite3.Conn, pragma string) (string, error) { stmt, _, err := c.Prepare(`PRAGMA ` + pragma) if err != nil { diff --git a/driver/error.go b/driver/error.go index cc8b2fd..240183f 100644 --- a/driver/error.go +++ b/driver/error.go @@ -4,4 +4,7 @@ type errorString string func (e errorString) Error() string { return string(e) } -const assertErr = errorString("sqlite3: assertion failed") +const ( + assertErr = errorString("sqlite3: assertion failed") + tailErr = errorString("sqlite3: multiple statements") +) diff --git a/driver/example_test.go b/driver/example_test.go index e396bd5..749db76 100644 --- a/driver/example_test.go +++ b/driver/example_test.go @@ -66,18 +66,14 @@ type Album struct { } func setupDatabase() error { - _, err := db.Exec(`DROP TABLE IF EXISTS album`) - if err != nil { - return err - } - - _, err = db.Exec(` + _, err := db.Exec(` + DROP TABLE IF EXISTS album; CREATE TABLE album ( id INTEGER PRIMARY KEY AUTOINCREMENT, title VARCHAR(128) NOT NULL, artist VARCHAR(255) NOT NULL, price DECIMAL(5,2) NOT NULL - ) + ); `) if err != nil { return err