From c351400be77d41310e70d0880c2807242048ea68 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 20 Feb 2023 13:30:01 +0000 Subject: [PATCH] Tests. --- conn_test.go | 23 +++- driver/driver.go | 19 ++- driver/driver_test.go | 291 +++++++++++++++++++++++++++++++++++++++++ driver/error.go | 2 +- driver/time.go | 10 +- driver/time_test.go | 4 + tests/db_test.go | 24 ++-- tests/driver_test.go | 101 ++++++++++++++ tests/parallel_test.go | 7 +- vfs_lock_test.go | 8 +- vfs_test.go | 6 +- 11 files changed, 465 insertions(+), 30 deletions(-) create mode 100644 driver/driver_test.go create mode 100644 tests/driver_test.go diff --git a/conn_test.go b/conn_test.go index b3eec7b..41b27ca 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "math" + "strings" "testing" ) @@ -51,7 +52,7 @@ func TestConn_SetInterrupt(t *testing.T) { } defer db.Close() - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(context.Background()) db.SetInterrupt(ctx.Done()) // Interrupt doesn't interrupt this. @@ -140,6 +141,26 @@ func TestConn_Prepare_Empty(t *testing.T) { } } +func TestConn_Prepare_Tail(t *testing.T) { + t.Parallel() + + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if !strings.Contains(tail, "-- HERE") { + t.Errorf("got %q", tail) + } +} + func TestConn_Prepare_Invalid(t *testing.T) { t.Parallel() diff --git a/driver/driver.go b/driver/driver.go index ac4d2e3..e20e790 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "database/sql/driver" + "fmt" "io" "net/url" "strings" @@ -48,7 +49,7 @@ func (sqlite) Open(name string) (driver.Conn, error) { err = c.Exec(pragmas.String()) if err != nil { - return nil, err + return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err) } return conn{ conn: c, @@ -61,7 +62,7 @@ type conn struct { conn *sqlite3.Conn pragmas string txBegin string - txRollback bool + txReadOnly bool } var ( @@ -101,11 +102,11 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er txBegin := c.txBegin if opts.ReadOnly { txBegin = ` - BEGIN DEFERRED; + BEGIN deferred; PRAGMA query_only=on; ` } - c.txRollback = opts.ReadOnly + c.txReadOnly = opts.ReadOnly err := c.conn.Exec(txBegin) if err != nil { @@ -115,7 +116,7 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er } func (c conn) Commit() error { - if c.txRollback { + if c.txReadOnly { return c.Rollback() } err := c.conn.Exec(`COMMIT`) @@ -198,7 +199,13 @@ func (s stmt) Close() error { } func (s stmt) NumInput() int { - return s.stmt.BindCount() + n := s.stmt.BindCount() + for i := 1; i <= n; i++ { + if s.stmt.BindName(i) != "" { + return -1 + } + } + return n } // Deprecated: use ExecContext instead. diff --git a/driver/driver_test.go b/driver/driver_test.go new file mode 100644 index 0000000..3d37f1f --- /dev/null +++ b/driver/driver_test.go @@ -0,0 +1,291 @@ +// Package driver provides a database/sql driver for SQLite. +package driver + +import ( + "bytes" + "context" + "database/sql" + "errors" + "math" + "path/filepath" + "testing" + "time" + + "github.com/ncruces/go-sqlite3" +) + +func Test_Open_dir(t *testing.T) { + db, err := sql.Open("sqlite3", ".") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Conn(context.TODO()) + if err == nil { + t.Fatal("want error") + } + var serr *sqlite3.Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.CANTOPEN { + t.Errorf("got %d, want sqlite3.CANTOPEN", rc) + } + if got := err.Error(); got != `sqlite3: unable to open database file` { + t.Error("got message: ", got) + } +} + +func Test_Open_pragma(t *testing.T) { + db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var timeout int + err = db.QueryRow(`PRAGMA busy_timeout`).Scan(&timeout) + if err != nil { + t.Fatal(err) + } + if timeout != 1000 { + t.Errorf("got %v, want 1000", timeout) + } +} + +func Test_Open_pragma_invalid(t *testing.T) { + db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Conn(context.TODO()) + if err == nil { + t.Fatal("want error") + } + var serr *sqlite3.Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` { + t.Error("got message: ", got) + } +} + +func Test_Open_txLock(t *testing.T) { + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db")+ + "?_txlock=exclusive&_pragma=busy_timeout(0)") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx1, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + _, err = db.Begin() + if err == nil { + t.Error("want error") + } + var serr *sqlite3.Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.BUSY { + t.Errorf("got %d, want sqlite3.BUSY", rc) + } + if got := err.Error(); got != `sqlite3: database is locked` { + t.Error("got message: ", got) + } + + err = tx1.Commit() + if err != nil { + t.Fatal(err) + } +} + +func Test_BeginTx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != isolationErr { + t.Error("want isolationErr") + } + + tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + + tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + + _, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) + if err == nil { + t.Error("want error") + } + var serr *sqlite3.Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.READONLY { + t.Errorf("got %d, want sqlite3.READONLY", rc) + } + if got := err.Error(); got != `sqlite3: attempt to write a readonly database` { + t.Error("got message: ", got) + } + + err = tx2.Commit() + if err != nil { + t.Fatal(err) + } + + err = tx1.Commit() + if err != nil { + t.Fatal(err) + } +} + +func Test_Prepare(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, err := db.Prepare(`SELECT 1; -- HERE`) + if err != nil { + t.Error(err) + } + defer stmt.Close() + + var serr *sqlite3.Error + _, err = db.Prepare(`SELECT`) + if err == nil { + t.Error("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` { + t.Error("got message: ", got) + } + + _, err = db.Prepare(`SELECT 1; SELECT`) + if err == nil { + t.Error("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` { + t.Error("got message: ", got) + } + + _, err = db.Prepare(`SELECT 1; SELECT 2`) + if err != tailErr { + t.Error("want tailErr") + } +} + +func Test_QueryRow_named(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + date := time.Now() + row := stmt.QueryRow(true, sql.Named("AAA", math.Pi), nil /*3*/, nil /*4*/, date /*5*/) + + var first bool + var fifth time.Time + var colon, at, dollar float32 + err = row.Scan(&first, &fifth, &colon, &at, &dollar) + if err != nil { + t.Fatal(err) + } + + if first != true { + t.Errorf("want true, got %v", first) + } + if colon != math.Pi { + t.Errorf("want π, got %v", colon) + } + if at != math.Pi { + t.Errorf("want π, got %v", at) + } + if dollar != math.Pi { + t.Errorf("want π, got %v", dollar) + } + if !fifth.Equal(date) { + t.Errorf("want %v, got %v", date, fifth) + } +} + +func Test_QueryRow_blob_null(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + rows, err := db.Query(` + SELECT NULL UNION ALL + SELECT x'cafe' UNION ALL + SELECT x'babe' UNION ALL + SELECT NULL + `) + if err != nil { + t.Fatal(err) + } + + want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil} + for i := 0; rows.Next(); i++ { + var buf []byte + err = rows.Scan(&buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want[i]) { + t.Errorf("got %q, want %q", buf, want[i]) + } + } +} diff --git a/driver/error.go b/driver/error.go index eee8cf9..8af54b1 100644 --- a/driver/error.go +++ b/driver/error.go @@ -7,5 +7,5 @@ func (e errorString) Error() string { return string(e) } const ( assertErr = errorString("sqlite3: assertion failed") tailErr = errorString("sqlite3: multiple statements") - isolationErr = errorString("sqlite3: unsupport isolation level") + isolationErr = errorString("sqlite3: unsupported isolation level") ) diff --git a/driver/time.go b/driver/time.go index 0a9a5b9..efac7dd 100644 --- a/driver/time.go +++ b/driver/time.go @@ -12,17 +12,15 @@ import ( func maybeDate(text string) driver.Value { // Weed out (some) values that can't possibly be // [time.RFC3339Nano] timestamps. - if len(text) < len("2006-01-02T15:04:05") { + if len(text) < len("2006-01-02T15:04:05Z") { + return text + } + if len(text) > len(time.RFC3339Nano) { return text } if text[4] != '-' || text[10] != 'T' || text[16] != ':' { return text } - for _, c := range []byte(text[:4]) { - if c < '0' || '9' < c { - return text - } - } // Slow path. date, err := time.Parse(time.RFC3339Nano, text) diff --git a/driver/time_test.go b/driver/time_test.go index 7fd3711..66d418e 100644 --- a/driver/time_test.go +++ b/driver/time_test.go @@ -15,6 +15,10 @@ func Fuzz_maybeDate(f *testing.F) { f.Add(time.DateTime) f.Add(time.DateOnly) f.Add(time.TimeOnly) + f.Add("2006-01-02T15:04:05Z") + f.Add("2006-01-02T15:04:05.000Z") + f.Add("2006-01-02T15:04:05.9999999999Z") + f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") f.Fuzz(func(t *testing.T, str string) { value := maybeDate(str) diff --git a/tests/db_test.go b/tests/db_test.go index 0f31367..dc1229a 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -32,6 +32,10 @@ func testDB(t *testing.T, name string) { if err != nil { t.Fatal(err) } + changes := db.Changes() + if changes != 3 { + t.Errorf("got %d want 3", changes) + } stmt, _, err := db.Prepare(`SELECT id, name FROM users`) if err != nil { @@ -43,18 +47,22 @@ func testDB(t *testing.T, name string) { ids := []int{0, 1, 2} names := []string{"go", "zig", "whatever"} for ; stmt.Step(); row++ { - if ids[row] != stmt.ColumnInt(0) { - t.Errorf("got %d, want %d", stmt.ColumnInt(0), ids[row]) + id := stmt.ColumnInt(0) + name := stmt.ColumnText(1) + + if id != ids[row] { + t.Errorf("got %d, want %d", id, ids[row]) } - if names[row] != stmt.ColumnText(1) { - t.Errorf("got %q, want %q", stmt.ColumnText(1), names[row]) + if name != names[row] { + t.Errorf("got %q, want %q", name, names[row]) } } - if err := stmt.Err(); err != nil { - t.Fatal(err) - } if row != 3 { - t.Errorf("got %d rows, want %d", row, len(ids)) + t.Errorf("got %d, want %d", row, len(ids)) + } + + if err := stmt.Err(); err != nil { + t.Fatal(err) } err = stmt.Close() diff --git a/tests/driver_test.go b/tests/driver_test.go new file mode 100644 index 0000000..7ea00a6 --- /dev/null +++ b/tests/driver_test.go @@ -0,0 +1,101 @@ +package tests + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func TestDriver(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + _, err = conn.ExecContext(ctx, + `CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) + if err != nil { + t.Fatal(err) + } + + res, err := conn.ExecContext(ctx, + `INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`) + if err != nil { + t.Fatal(err) + } + changes, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } + if changes != 3 { + t.Errorf("got %d want 3", changes) + } + + stmt, err := conn.PrepareContext(context.Background(), + `SELECT id, name FROM users`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + row := 0 + ids := []int{0, 1, 2} + names := []string{"go", "zig", "whatever"} + for ; rows.Next(); row++ { + var id int + var name string + err := rows.Scan(&id, &name) + if err != nil { + t.Fatal(err) + } + + if id != ids[row] { + t.Errorf("got %d, want %d", id, ids[row]) + } + if name != names[row] { + t.Errorf("got %q, want %q", name, names[row]) + } + } + if row != 3 { + t.Errorf("got %d, want %d", row, len(ids)) + } + + err = rows.Close() + if err != nil { + t.Fatal(err) + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + err = conn.Close() + if err != nil { + t.Fatal(err) + } + + err = db.Close() + if err != nil { + t.Fatal(err) + } +} diff --git a/tests/parallel_test.go b/tests/parallel_test.go index 57e2a73..a251e92 100644 --- a/tests/parallel_test.go +++ b/tests/parallel_test.go @@ -1,6 +1,7 @@ package tests import ( + "errors" "io" "os" "os/exec" @@ -44,7 +45,11 @@ func TestMultiProcess(t *testing.T) { testParallel(t, name, 1000) if err := cmd.Wait(); err != nil { - t.Fatal(err) + t.Error(err) + var eerr *exec.ExitError + if errors.As(err, &eerr) { + t.Error(eerr.Stderr) + } } testIntegrity(t, name) } diff --git a/vfs_lock_test.go b/vfs_lock_test.go index 13ebce4..bd62844 100644 --- a/vfs_lock_test.go +++ b/vfs_lock_test.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "os" + "path/filepath" "runtime" "testing" ) @@ -16,16 +17,15 @@ func Test_vfsLock(t *testing.T) { t.Skip() } + name := filepath.Join(t.TempDir(), "test.db") + // Create a temporary file. - file1, err := os.CreateTemp("", "sqlite3-") + file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) if err != nil { t.Fatal(err) } defer file1.Close() - name := file1.Name() - defer os.RemoveAll(name) - // Open the temporary file again. file2, err := os.OpenFile(name, os.O_RDWR, 0) if err != nil { diff --git a/vfs_test.go b/vfs_test.go index 6af6d21..7262990 100644 --- a/vfs_test.go +++ b/vfs_test.go @@ -136,12 +136,12 @@ func Test_vfsFullPathname(t *testing.T) { } func Test_vfsDelete(t *testing.T) { - file, err := os.CreateTemp("", "sqlite3-") + name := filepath.Join(t.TempDir(), "test.db") + + file, err := os.Create(name) if err != nil { t.Fatal(err) } - name := file.Name() - defer os.RemoveAll(name) file.Close() mem := newMemory(128 + _MAX_PATHNAME)