From e64bffa520e1c9d3d7f47d0b85b4f098e85bb1c6 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 28 Feb 2023 16:03:31 +0000 Subject: [PATCH] Pragmas. --- conn.go | 29 +++++++++++++++++++++++----- driver/driver.go | 47 +++++++++++++++++++++++----------------------- tests/conn_test.go | 14 ++++++++++++++ 3 files changed, 62 insertions(+), 28 deletions(-) diff --git a/conn.go b/conn.go index b7193e2..cf0d1aa 100644 --- a/conn.go +++ b/conn.go @@ -3,7 +3,10 @@ package sqlite3 import ( "context" "database/sql/driver" + "fmt" "math" + "net/url" + "strings" "sync" "github.com/tetratelabs/wazero/api" @@ -25,13 +28,17 @@ type Conn struct { pending *Stmt } -// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE]. +// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. func Open(filename string) (conn *Conn, err error) { - return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE) + return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) } // OpenFlags opens an SQLite database file as specified by the filename argument. // +// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma": +// +// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)") +// // https://www.sqlite.org/c3ref/open.html func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { ctx := context.Background() @@ -61,6 +68,21 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { if err := c.error(r[0]); err != nil { return nil, err } + + if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { + var pragmas strings.Builder + if _, after, ok := strings.Cut(filename, "?"); ok { + query, _ := url.ParseQuery(after) + for _, p := range query["_pragma"] { + pragmas.WriteString(`PRAGMA `) + pragmas.WriteString(p) + pragmas.WriteByte(';') + } + } + if err := c.Exec(pragmas.String()); err != nil { + return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err) + } + } return c, nil } @@ -276,9 +298,6 @@ func (c *Conn) Pragma(str string) []string { for stmt.Step() { pragmas = append(pragmas, stmt.ColumnText(0)) } - if err := stmt.Err(); err != nil { - panic(err) - } return pragmas } diff --git a/driver/driver.go b/driver/driver.go index cea9e34..0c56212 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -43,41 +43,42 @@ func init() { type sqlite struct{} -func (sqlite) Open(name string) (driver.Conn, error) { +func (sqlite) Open(name string) (_ driver.Conn, err error) { c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE) if err != nil { return nil, err } var txBegin string - var pragmas strings.Builder - if _, after, ok := strings.Cut(name, "?"); ok { - query, _ := url.ParseQuery(after) + var pragmas []string + if strings.HasPrefix(name, "file:") { + if _, after, ok := strings.Cut(name, "?"); ok { + query, _ := url.ParseQuery(after) - switch s := query.Get("_txlock"); s { - case "": - txBegin = "BEGIN" - case "deferred", "immediate", "exclusive": - txBegin = "BEGIN " + s - default: - return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s) - } + switch s := query.Get("_txlock"); s { + case "": + txBegin = "BEGIN" + case "deferred", "immediate", "exclusive": + txBegin = "BEGIN " + s + default: + c.Close() + return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s) + } - for _, p := range query["_pragma"] { - pragmas.WriteString(`PRAGMA `) - pragmas.WriteString(p) - pragmas.WriteByte(';') + pragmas = query["_pragma"] } } - if pragmas.Len() == 0 { - pragmas.WriteString(`PRAGMA busy_timeout=60000;`) - pragmas.WriteString(`PRAGMA locking_mode=normal;`) + if len(pragmas) == 0 { + err := c.Exec(` + PRAGMA busy_timeout=60000; + PRAGMA locking_mode=normal; + `) + if err != nil { + c.Close() + return nil, err + } } - err = c.Exec(pragmas.String()) - if err != nil { - return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err) - } return conn{ conn: c, txBegin: txBegin, diff --git a/tests/conn_test.go b/tests/conn_test.go index df7b0b4..feb0534 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -248,3 +248,17 @@ func TestConn_MustPrepare_invalid(t *testing.T) { _ = db.MustPrepare(`SELECT`) t.Error("want panic") } + +func TestConn_Pragma(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + defer func() { _ = recover() }() + _ = db.Pragma("encoding=''") + t.Error("want panic") +}