From ad27d5d840d83caa982be992411c90174b65237b Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 18 Feb 2023 12:20:42 +0000 Subject: [PATCH] Support pragmas, integration test. --- driver/bradfitz_test.go | 186 ++++++++++++++++++++++++++++++++++++++++ driver/driver.go | 40 ++++++--- 2 files changed, 215 insertions(+), 11 deletions(-) create mode 100644 driver/bradfitz_test.go diff --git a/driver/bradfitz_test.go b/driver/bradfitz_test.go new file mode 100644 index 0000000..1793d67 --- /dev/null +++ b/driver/bradfitz_test.go @@ -0,0 +1,186 @@ +package driver_test + +import ( + "database/sql" + "fmt" + "math/rand" + "path/filepath" + "testing" + + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) + +type Tester interface { + RunTest(*testing.T, func(params)) +} + +var ( + sqlite Tester = sqliteDB{} +) + +const TablePrefix = "gosqltest_" + +type sqliteDB struct{} + +type params struct { + dbType Tester + *testing.T + *sql.DB +} + +func (t params) mustExec(sql string, args ...interface{}) sql.Result { + res, err := t.DB.Exec(sql, args...) + if err != nil { + t.Fatalf("Error running %q: %v", sql, err) + } + return res +} + +// q converts "?" characters to $1, $2, $n on postgres, :1, :2, :n on Oracle +func (t params) q(sql string) string { + return sql +} + +func (sqliteDB) RunTest(t *testing.T, fn func(params)) { + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db")) + if err != nil { + t.Fatalf("foo.db open fail: %v", err) + } + fn(params{sqlite, t, db}) + if err := db.Close(); err != nil { + t.Fatalf("foo.db close fail: %v", err) + } +} + +func sqlBlobParam(t params, size int) string { + return fmt.Sprintf("blob[%d]", size) +} + +func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) } + +func testBlobs(t params) { + var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar " + sqlBlobParam(t, 16) + ")") + t.mustExec(t.q("insert into "+TablePrefix+"foo (id, bar) values(?,?)"), 0, blob) + + want := fmt.Sprintf("%x", blob) + + b := make([]byte, 16) + err := t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&b) + got := fmt.Sprintf("%x", b) + if err != nil { + t.Errorf("[]byte scan: %v", err) + } else if got != want { + t.Errorf("for []byte, got %q; want %q", got, want) + } + + err = t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&got) + want = string(blob) + if err != nil { + t.Errorf("string scan: %v", err) + } else if got != want { + t.Errorf("for string, got %q; want %q", got, want) + } +} + +func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) } + +func testManyQueryRow(t params) { + if testing.Short() { + t.Logf("skipping in short mode") + return + } + t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))") + t.mustExec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob") + var name string + for i := 0; i < 10000; i++ { + err := t.QueryRow(t.q("select name from "+TablePrefix+"foo where id = ?"), 1).Scan(&name) + if err != nil || name != "bob" { + t.Fatalf("on query %d: err=%v, name=%q", i, err, name) + } + } +} + +func TestTxQuery_SQLite(t *testing.T) { sqlite.RunTest(t, testTxQuery) } + +func testTxQuery(t params) { + tx, err := t.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = t.DB.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))") + if err != nil { + t.Logf("cannot drop table "+TablePrefix+"foo: %s", err) + } + + _, err = tx.Exec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob") + if err != nil { + t.Fatal(err) + } + + r, err := tx.Query(t.q("select name from "+TablePrefix+"foo where id = ?"), 1) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + var name string + err = r.Scan(&name) + if err != nil { + t.Fatal(err) + } +} + +func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) } + +func testPreparedStmt(t params) { + t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)") + sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC") + if err != nil { + t.Fatalf("prepare 1: %v", err) + } + ins, err := t.Prepare(t.q("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")) + if err != nil { + t.Fatalf("prepare 2: %v", err) + } + + for n := 1; n <= 3; n++ { + if _, err := ins.Exec(n); err != nil { + t.Fatalf("insert(%d) = %v", n, err) + } + } + + const nRuns = 10 + ch := make(chan bool) + for i := 0; i < nRuns; i++ { + go func() { + defer func() { + ch <- true + }() + for j := 0; j < 10; j++ { + count := 0 + if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { + t.Errorf("Query: %v", err) + return + } + if _, err := ins.Exec(rand.Intn(100)); err != nil { + t.Errorf("Insert: %v", err) + return + } + } + }() + } + for i := 0; i < nRuns; i++ { + <-ch + } +} diff --git a/driver/driver.go b/driver/driver.go index fe0d989..4a17839 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -6,6 +6,8 @@ import ( "database/sql" "database/sql/driver" "io" + "net/url" + "strings" "time" "github.com/ncruces/go-sqlite3" @@ -18,30 +20,42 @@ func init() { type sqlite struct{} func (sqlite) Open(name string) (driver.Conn, error) { + u, err := url.Parse(name) + if err != nil { + return nil, err + } c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE) if err != nil { return nil, err } - // If the database is not in WAL mode, - // use normal locking mode. - journal, err := pragma(c, "journal_mode") + var pragmas strings.Builder + for _, p := range u.Query()["_pragma"] { + pragmas.WriteString(`PRAGMA `) + pragmas.WriteString(p) + pragmas.WriteByte(';') + } + if pragmas.Len() == 0 { + pragmas.WriteString(`PRAGMA locking_mode=normal;`) + pragmas.WriteString(`PRAGMA busy_timeout=60000;`) + } + err = c.Exec(pragmas.String()) if err != nil { return nil, err } - if journal != "wal" { - pragma(c, "locking_mode=normal") - } - return conn{c}, nil + return conn{c, pragmas.String()}, nil } -type conn struct{ conn *sqlite3.Conn } +type conn struct { + conn *sqlite3.Conn + pragmas string +} var ( // Ensure these interfaces are implemented: - _ driver.Validator = conn{} - _ driver.ExecerContext = conn{} + _ driver.Validator = conn{} + _ driver.SessionResetter = conn{} + _ driver.ExecerContext = conn{} // _ driver.ConnBeginTx = conn{} - // _ driver.SessionResetter = conn{} ) func (c conn) Close() error { @@ -54,6 +68,10 @@ func (c conn) IsValid() bool { return mode == "normal" } +func (c conn) ResetSession(ctx context.Context) error { + return c.conn.Exec(c.pragmas) +} + func (c conn) Begin() (driver.Tx, error) { err := c.conn.Exec(`BEGIN`) if err != nil {