diff --git a/driver/driver.go b/driver/driver.go index 8d0305b..14acc8c 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -45,12 +45,12 @@ func init() { type sqlite struct{} func (sqlite) Open(name string) (_ driver.Conn, err error) { - c, err := sqlite3.Open(name) + var c conn + c.conn, err = sqlite3.Open(name) if err != nil { return nil, err } - var txBegin string var pragmas []string if strings.HasPrefix(name, "file:") { if _, after, ok := strings.Cut(name, "?"); ok { @@ -58,9 +58,9 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) { switch s := query.Get("_txlock"); s { case "": - txBegin = "BEGIN" + c.txBegin = "BEGIN" case "deferred", "immediate", "exclusive": - txBegin = "BEGIN " + s + c.txBegin = "BEGIN " + s default: c.Close() return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s) @@ -70,20 +70,36 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) { } } if len(pragmas) == 0 { - err := c.Exec(` - PRAGMA busy_timeout=60000; + err := c.conn.Exec(` PRAGMA locking_mode=normal; + PRAGMA busy_timeout=60000; `) if err != nil { c.Close() return nil, err } + c.reusable = true + } else { + s, _, err := c.conn.Prepare(` + SELECT * FROM + PRAGMA_locking_mode, + PRAGMA_query_only; + `) + if err != nil { + c.Close() + return nil, err + } + if s.Step() { + c.reusable = s.ColumnText(0) == "normal" + c.readOnly = s.ColumnRawText(1)[0] // 0 or 1 + } + err = s.Close() + if err != nil { + c.Close() + return nil, err + } } - - return conn{ - conn: c, - txBegin: txBegin, - }, nil + return c, nil } type conn struct { @@ -91,6 +107,8 @@ type conn struct { txBegin string txCommit string txRollback string + reusable bool + readOnly byte } var ( @@ -105,9 +123,8 @@ func (c conn) Close() error { return c.conn.Close() } -func (c conn) IsValid() (valid bool) { - r, err := c.conn.Pragma("locking_mode") - return err == nil && len(r) == 1 && r[0] == "normal" +func (c conn) IsValid() bool { + return c.reusable } func (c conn) Begin() (driver.Tx, error) { @@ -120,16 +137,12 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro c.txRollback = `ROLLBACK` if opts.ReadOnly { - query_only, err := c.conn.Pragma("query_only") - if err != nil { - return nil, err - } txBegin = ` BEGIN deferred; PRAGMA query_only=on` c.txCommit = ` ROLLBACK; - PRAGMA query_only=` + query_only[0] + PRAGMA query_only=` + string(c.readOnly) c.txRollback = c.txCommit } @@ -140,14 +153,6 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro driver.IsolationLevel(sql.LevelDefault), driver.IsolationLevel(sql.LevelSerializable): break - case driver.IsolationLevel(sql.LevelReadUncommitted): - read_uncommitted, err := c.conn.Pragma("read_uncommitted") - if err != nil { - return nil, err - } - txBegin += `; PRAGMA read_uncommitted=on` - c.txCommit += `; PRAGMA read_uncommitted=` + read_uncommitted[0] - c.txRollback += `; PRAGMA read_uncommitted=` + read_uncommitted[0] } err := c.conn.Exec(txBegin) diff --git a/driver/driver_test.go b/driver/driver_test.go index 2656309..12a4cfb 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -147,16 +147,6 @@ func Test_BeginTx(t *testing.T) { t.Error("want isolationErr") } - tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted}) - if err != nil { - t.Fatal(err) - } - - err = tx.Rollback() - if err != nil { - t.Fatal(err) - } - tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) if err != nil { t.Fatal(err)