Fix readonly transaction rollback.

This commit is contained in:
Nuno Cruces
2023-03-08 18:05:18 +00:00
parent 926adeb3f5
commit 66a730893f
2 changed files with 42 additions and 16 deletions

View File

@@ -86,9 +86,10 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
type conn struct {
conn *sqlite3.Conn
txBegin string
txCommit string
conn *sqlite3.Conn
txBegin string
txCommit string
txRollback string
}
var (
@@ -107,26 +108,39 @@ func (c conn) Begin() (driver.Tx, error) {
}
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
query_only, err := c.conn.Pragma("query_only")
if err != nil {
return nil, err
}
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + query_only[0]
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + query_only[0]
c.txRollback = c.txCommit
}
switch opts.Isolation {
default:
return nil, isolationErr
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
case driver.IsolationLevel(sql.LevelReadUncommitted):
read_uncommitted, err := c.conn.Pragma("read_uncommitted")
if err != nil {
return nil, err
}
txBegin += `; PRAGMA read_uncommitted=on`
c.txCommit += `; PRAGMA read_uncommitted=` + read_uncommitted[0]
c.txRollback += `; PRAGMA read_uncommitted=` + read_uncommitted[0]
}
err := c.conn.Exec(txBegin)
@@ -138,14 +152,14 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil {
if err != nil && !c.conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
return c.conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {

View File

@@ -134,7 +134,9 @@ func Test_BeginTx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
@@ -145,6 +147,16 @@ func Test_BeginTx(t *testing.T) {
t.Error("want isolationErr")
}
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted})
if err != nil {
t.Fatal(err)
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)