From 231d3a04388ad05a1e2cb4772e928187e8c742ae Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 19 Feb 2023 16:16:13 +0000 Subject: [PATCH] Read-only transactions, locking. --- driver/bradfitz_test.go | 5 ++++ driver/driver.go | 63 ++++++++++++++++++++++++++++++++--------- driver/error.go | 5 ++-- util_test.go | 12 +++++--- 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/driver/bradfitz_test.go b/driver/bradfitz_test.go index 1793d67..005fcc0 100644 --- a/driver/bradfitz_test.go +++ b/driver/bradfitz_test.go @@ -144,6 +144,11 @@ func testTxQuery(t params) { func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) } func testPreparedStmt(t params) { + if testing.Short() { + t.Logf("skipping in short mode") + return + } + t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)") sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC") if err != nil { diff --git a/driver/driver.go b/driver/driver.go index 8f04b8a..ac4d2e3 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -20,34 +20,48 @@ func init() { type sqlite struct{} func (sqlite) Open(name string) (driver.Conn, error) { - u, err := url.Parse(name) - if err != nil { - return nil, err - } c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE) if err != nil { return nil, err } + + var txBegin = "BEGIN " var pragmas strings.Builder - for _, p := range u.Query()["_pragma"] { - pragmas.WriteString(`PRAGMA `) - pragmas.WriteString(p) - pragmas.WriteByte(';') + if _, after, ok := strings.Cut(name, "?"); ok { + query, _ := url.ParseQuery(after) + + switch v := query.Get("_txlock"); v { + case "deferred", "immediate", "exclusive": + txBegin += v + } + + for _, p := range query["_pragma"] { + pragmas.WriteString(`PRAGMA `) + pragmas.WriteString(p) + pragmas.WriteByte(';') + } } if pragmas.Len() == 0 { pragmas.WriteString(`PRAGMA locking_mode=normal;`) pragmas.WriteString(`PRAGMA busy_timeout=60000;`) } + err = c.Exec(pragmas.String()) if err != nil { return nil, err } - return conn{c, pragmas.String()}, nil + return conn{ + conn: c, + txBegin: txBegin, + pragmas: pragmas.String(), + }, nil } type conn struct { - conn *sqlite3.Conn - pragmas string + conn *sqlite3.Conn + pragmas string + txBegin string + txRollback bool } var ( @@ -55,7 +69,7 @@ var ( _ driver.Validator = conn{} _ driver.SessionResetter = conn{} _ driver.ExecerContext = conn{} - // _ driver.ConnBeginTx = conn{} + _ driver.ConnBeginTx = conn{} ) func (c conn) Close() error { @@ -73,7 +87,27 @@ func (c conn) ResetSession(ctx context.Context) error { } func (c conn) Begin() (driver.Tx, error) { - err := c.conn.Exec(`BEGIN`) + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + switch opts.Isolation { + default: + return nil, isolationErr + case driver.IsolationLevel(sql.LevelDefault): + case driver.IsolationLevel(sql.LevelSerializable): + } + + txBegin := c.txBegin + if opts.ReadOnly { + txBegin = ` + BEGIN DEFERRED; + PRAGMA query_only=on; + ` + } + c.txRollback = opts.ReadOnly + + err := c.conn.Exec(txBegin) if err != nil { return nil, err } @@ -81,6 +115,9 @@ func (c conn) Begin() (driver.Tx, error) { } func (c conn) Commit() error { + if c.txRollback { + return c.Rollback() + } err := c.conn.Exec(`COMMIT`) if err != nil { c.Rollback() diff --git a/driver/error.go b/driver/error.go index 240183f..eee8cf9 100644 --- a/driver/error.go +++ b/driver/error.go @@ -5,6 +5,7 @@ type errorString string func (e errorString) Error() string { return string(e) } const ( - assertErr = errorString("sqlite3: assertion failed") - tailErr = errorString("sqlite3: multiple statements") + assertErr = errorString("sqlite3: assertion failed") + tailErr = errorString("sqlite3: multiple statements") + isolationErr = errorString("sqlite3: unsupport isolation level") ) diff --git a/util_test.go b/util_test.go index 48e6b95..391e06f 100644 --- a/util_test.go +++ b/util_test.go @@ -19,7 +19,7 @@ func Test_emptyStatement(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := emptyStatement(tt.stmt); got != tt.want { - t.Errorf("emptyStatement(%q) = %v, want %v", tt.stmt, got, tt.want) + t.Errorf("got %v, want %v", got, tt.want) } }) } @@ -29,6 +29,7 @@ func Fuzz_emptyStatement(f *testing.F) { f.Add("") f.Add(" ") f.Add(";\n ") + f.Add("; ;\v") f.Add("BEGIN") f.Add("SELECT 1;") @@ -41,12 +42,15 @@ func Fuzz_emptyStatement(f *testing.F) { f.Fuzz(func(t *testing.T, sql string) { // If empty, SQLite parses it as empty. if emptyStatement(sql) { - stmt, _, err := db.Prepare(sql) + stmt, tail, err := db.Prepare(sql) if err != nil { - t.Error(err) + t.Errorf("%q, %v", sql, err) } if stmt != nil { - t.Error(stmt) + t.Errorf("%q, %v", sql, stmt) + } + if tail != "" { + t.Errorf("%q", sql) } stmt.Close() }