This commit is contained in:
Nuno Cruces
2023-02-28 16:03:31 +00:00
parent 54046b6adc
commit e64bffa520
3 changed files with 62 additions and 28 deletions

29
conn.go
View File

@@ -3,7 +3,10 @@ package sqlite3
import (
"context"
"database/sql/driver"
"fmt"
"math"
"net/url"
"strings"
"sync"
"github.com/tetratelabs/wazero/api"
@@ -25,13 +28,17 @@ type Conn struct {
pending *Stmt
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (conn *Conn, err error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE)
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
@@ -61,6 +68,21 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
if err := c.error(r[0]); err != nil {
return nil, err
}
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
query, _ := url.ParseQuery(after)
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
}
}
if err := c.Exec(pragmas.String()); err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
}
return c, nil
}
@@ -276,9 +298,6 @@ func (c *Conn) Pragma(str string) []string {
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
panic(err)
}
return pragmas
}

View File

@@ -43,41 +43,42 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
func (sqlite) Open(name string) (_ driver.Conn, err error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
if err != nil {
return nil, err
}
var txBegin string
var pragmas strings.Builder
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
pragmas = query["_pragma"]
}
}
if pragmas.Len() == 0 {
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
if len(pragmas) == 0 {
err := c.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
if err != nil {
c.Close()
return nil, err
}
}
err = c.Exec(pragmas.String())
if err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
return conn{
conn: c,
txBegin: txBegin,

View File

@@ -248,3 +248,17 @@ func TestConn_MustPrepare_invalid(t *testing.T) {
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.Pragma("encoding=''")
t.Error("want panic")
}