mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Read-only transactions, locking.
This commit is contained in:
@@ -144,6 +144,11 @@ func testTxQuery(t params) {
|
||||
func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) }
|
||||
|
||||
func testPreparedStmt(t params) {
|
||||
if testing.Short() {
|
||||
t.Logf("skipping in short mode")
|
||||
return
|
||||
}
|
||||
|
||||
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
|
||||
sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
|
||||
@@ -20,34 +20,48 @@ func init() {
|
||||
type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
u, err := url.Parse(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var txBegin = "BEGIN "
|
||||
var pragmas strings.Builder
|
||||
for _, p := range u.Query()["_pragma"] {
|
||||
pragmas.WriteString(`PRAGMA `)
|
||||
pragmas.WriteString(p)
|
||||
pragmas.WriteByte(';')
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, _ := url.ParseQuery(after)
|
||||
|
||||
switch v := query.Get("_txlock"); v {
|
||||
case "deferred", "immediate", "exclusive":
|
||||
txBegin += v
|
||||
}
|
||||
|
||||
for _, p := range query["_pragma"] {
|
||||
pragmas.WriteString(`PRAGMA `)
|
||||
pragmas.WriteString(p)
|
||||
pragmas.WriteByte(';')
|
||||
}
|
||||
}
|
||||
if pragmas.Len() == 0 {
|
||||
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
|
||||
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
|
||||
}
|
||||
|
||||
err = c.Exec(pragmas.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn{c, pragmas.String()}, nil
|
||||
return conn{
|
||||
conn: c,
|
||||
txBegin: txBegin,
|
||||
pragmas: pragmas.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
conn *sqlite3.Conn
|
||||
pragmas string
|
||||
conn *sqlite3.Conn
|
||||
pragmas string
|
||||
txBegin string
|
||||
txRollback bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -55,7 +69,7 @@ var (
|
||||
_ driver.Validator = conn{}
|
||||
_ driver.SessionResetter = conn{}
|
||||
_ driver.ExecerContext = conn{}
|
||||
// _ driver.ConnBeginTx = conn{}
|
||||
_ driver.ConnBeginTx = conn{}
|
||||
)
|
||||
|
||||
func (c conn) Close() error {
|
||||
@@ -73,7 +87,27 @@ func (c conn) ResetSession(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
err := c.conn.Exec(`BEGIN`)
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
switch opts.Isolation {
|
||||
default:
|
||||
return nil, isolationErr
|
||||
case driver.IsolationLevel(sql.LevelDefault):
|
||||
case driver.IsolationLevel(sql.LevelSerializable):
|
||||
}
|
||||
|
||||
txBegin := c.txBegin
|
||||
if opts.ReadOnly {
|
||||
txBegin = `
|
||||
BEGIN DEFERRED;
|
||||
PRAGMA query_only=on;
|
||||
`
|
||||
}
|
||||
c.txRollback = opts.ReadOnly
|
||||
|
||||
err := c.conn.Exec(txBegin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,6 +115,9 @@ func (c conn) Begin() (driver.Tx, error) {
|
||||
}
|
||||
|
||||
func (c conn) Commit() error {
|
||||
if c.txRollback {
|
||||
return c.Rollback()
|
||||
}
|
||||
err := c.conn.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
c.Rollback()
|
||||
|
||||
@@ -5,6 +5,7 @@ type errorString string
|
||||
func (e errorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
assertErr = errorString("sqlite3: assertion failed")
|
||||
tailErr = errorString("sqlite3: multiple statements")
|
||||
assertErr = errorString("sqlite3: assertion failed")
|
||||
tailErr = errorString("sqlite3: multiple statements")
|
||||
isolationErr = errorString("sqlite3: unsupport isolation level")
|
||||
)
|
||||
|
||||
12
util_test.go
12
util_test.go
@@ -19,7 +19,7 @@ func Test_emptyStatement(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := emptyStatement(tt.stmt); got != tt.want {
|
||||
t.Errorf("emptyStatement(%q) = %v, want %v", tt.stmt, got, tt.want)
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -29,6 +29,7 @@ func Fuzz_emptyStatement(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add(";\n ")
|
||||
f.Add("; ;\v")
|
||||
f.Add("BEGIN")
|
||||
f.Add("SELECT 1;")
|
||||
|
||||
@@ -41,12 +42,15 @@ func Fuzz_emptyStatement(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, sql string) {
|
||||
// If empty, SQLite parses it as empty.
|
||||
if emptyStatement(sql) {
|
||||
stmt, _, err := db.Prepare(sql)
|
||||
stmt, tail, err := db.Prepare(sql)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.Errorf("%q, %v", sql, err)
|
||||
}
|
||||
if stmt != nil {
|
||||
t.Error(stmt)
|
||||
t.Errorf("%q, %v", sql, stmt)
|
||||
}
|
||||
if tail != "" {
|
||||
t.Errorf("%q", sql)
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user