Support pragmas, integration test.

This commit is contained in:
Nuno Cruces
2023-02-18 12:20:42 +00:00
parent ec5bd236f8
commit ad27d5d840
2 changed files with 215 additions and 11 deletions

186
driver/bradfitz_test.go Normal file
View File

@@ -0,0 +1,186 @@
package driver_test
import (
"database/sql"
"fmt"
"math/rand"
"path/filepath"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
type Tester interface {
RunTest(*testing.T, func(params))
}
var (
sqlite Tester = sqliteDB{}
)
const TablePrefix = "gosqltest_"
type sqliteDB struct{}
type params struct {
dbType Tester
*testing.T
*sql.DB
}
func (t params) mustExec(sql string, args ...interface{}) sql.Result {
res, err := t.DB.Exec(sql, args...)
if err != nil {
t.Fatalf("Error running %q: %v", sql, err)
}
return res
}
// q converts "?" characters to $1, $2, $n on postgres, :1, :2, :n on Oracle
func (t params) q(sql string) string {
return sql
}
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db"))
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}
fn(params{sqlite, t, db})
if err := db.Close(); err != nil {
t.Fatalf("foo.db close fail: %v", err)
}
}
func sqlBlobParam(t params, size int) string {
return fmt.Sprintf("blob[%d]", size)
}
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
func testBlobs(t params) {
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar " + sqlBlobParam(t, 16) + ")")
t.mustExec(t.q("insert into "+TablePrefix+"foo (id, bar) values(?,?)"), 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) }
func testManyQueryRow(t params) {
if testing.Short() {
t.Logf("skipping in short mode")
return
}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
t.mustExec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := t.QueryRow(t.q("select name from "+TablePrefix+"foo where id = ?"), 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
func TestTxQuery_SQLite(t *testing.T) { sqlite.RunTest(t, testTxQuery) }
func testTxQuery(t params) {
tx, err := t.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = t.DB.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
if err != nil {
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
}
_, err = tx.Exec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(t.q("select name from "+TablePrefix+"foo where id = ?"), 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) }
func testPreparedStmt(t params) {
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := t.Prepare(t.q("INSERT INTO " + TablePrefix + "t (count) VALUES (?)"))
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
ch := make(chan bool)
for i := 0; i < nRuns; i++ {
go func() {
defer func() {
ch <- true
}()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
for i := 0; i < nRuns; i++ {
<-ch
}
}

View File

@@ -6,6 +6,8 @@ import (
"database/sql"
"database/sql/driver"
"io"
"net/url"
"strings"
"time"
"github.com/ncruces/go-sqlite3"
@@ -18,30 +20,42 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
u, err := url.Parse(name)
if err != nil {
return nil, err
}
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
if err != nil {
return nil, err
}
// If the database is not in WAL mode,
// use normal locking mode.
journal, err := pragma(c, "journal_mode")
var pragmas strings.Builder
for _, p := range u.Query()["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
}
if pragmas.Len() == 0 {
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
}
err = c.Exec(pragmas.String())
if err != nil {
return nil, err
}
if journal != "wal" {
pragma(c, "locking_mode=normal")
}
return conn{c}, nil
return conn{c, pragmas.String()}, nil
}
type conn struct{ conn *sqlite3.Conn }
type conn struct {
conn *sqlite3.Conn
pragmas string
}
var (
// Ensure these interfaces are implemented:
_ driver.Validator = conn{}
_ driver.ExecerContext = conn{}
_ driver.Validator = conn{}
_ driver.SessionResetter = conn{}
_ driver.ExecerContext = conn{}
// _ driver.ConnBeginTx = conn{}
// _ driver.SessionResetter = conn{}
)
func (c conn) Close() error {
@@ -54,6 +68,10 @@ func (c conn) IsValid() bool {
return mode == "normal"
}
func (c conn) ResetSession(ctx context.Context) error {
return c.conn.Exec(c.pragmas)
}
func (c conn) Begin() (driver.Tx, error) {
err := c.conn.Exec(`BEGIN`)
if err != nil {