From e91758c6a4e77fc0055acf6bca42207b58b85521 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 22 Feb 2023 14:19:56 +0000 Subject: [PATCH] Zero blobs, tests, documentation --- .github/FUNDING.yml | 1 + README.md | 7 +- blob.go | 6 + conn_test.go | 202 --------------- const.go | 1 + const_test.go | 2 + driver/driver.go | 19 +- driver/driver_test.go | 38 ++- stmt.go | 13 + .../bradfitz/sql_test.go | 49 ++-- tests/compile/empty/compile_test.go | 2 +- tests/compile/missing/compile_test.go | 4 +- tests/compile/nil/compile_test.go | 4 +- tests/conn_test.go | 229 +++++++++++++++++ tests/db_test.go | 2 + tests/dir_test.go | 26 -- tests/driver_test.go | 2 + tests/{ => parallel}/parallel_test.go | 4 +- stmt_test.go => tests/stmt_test.go | 234 +++++++++--------- time.go | 31 ++- time_test.go | 4 + util_test.go | 2 + 22 files changed, 487 insertions(+), 395 deletions(-) create mode 100644 .github/FUNDING.yml create mode 100644 blob.go rename driver/bradfitz_test.go => tests/bradfitz/sql_test.go (74%) create mode 100644 tests/conn_test.go delete mode 100644 tests/dir_test.go rename tests/{ => parallel}/parallel_test.go (99%) rename stmt_test.go => tests/stmt_test.go (69%) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..2955ba5 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6 \ No newline at end of file diff --git a/README.md b/README.md index bb82da9..035aaa6 100644 --- a/README.md +++ b/README.md @@ -18,4 +18,9 @@ Roadmap: - [x] provide a simple `database/sql` driver - [x] file locking, compatible with SQLite on Windows/Unix - [ ] shared memory, compatible with SQLite on Windows/Unix - - needed for improved WAL mode \ No newline at end of file + - needed for improved WAL mode +- [ ] advanced features + - [ ] incremental BLOB I/O + - [ ] online backup + - [ ] session extension + - [ ] snapshot \ No newline at end of file diff --git a/blob.go b/blob.go new file mode 100644 index 0000000..850be5f --- /dev/null +++ b/blob.go @@ -0,0 +1,6 @@ +package sqlite3 + +// ZeroBlob represents a zero-filled, length n BLOB +// that can be used as an argument to +// [database.sql.DB.Exec] and similar methods. +type ZeroBlob int64 diff --git a/conn_test.go b/conn_test.go index f8832d4..94e1195 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,212 +2,10 @@ package sqlite3 import ( "bytes" - "context" - "errors" "math" - "strings" "testing" ) -func TestConn_Close(t *testing.T) { - var conn *Conn - conn.Close() -} - -func TestConn_Close_BUSY(t *testing.T) { - t.Parallel() - - db, err := Open(":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - stmt, _, err := db.Prepare(`BEGIN`) - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - - err = db.Close() - if err == nil { - t.Fatal("want error") - } - var serr *Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != BUSY { - t.Errorf("got %d, want sqlite3.BUSY", rc) - } - var terr interface{ Temporary() bool } - if !errors.As(err, &terr) || !terr.Temporary() { - t.Error("not temporary", err) - } - if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` { - t.Error("got message: ", got) - } -} - -func TestConn_SetInterrupt(t *testing.T) { - db, err := Open(":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - ctx, cancel := context.WithCancel(context.Background()) - db.SetInterrupt(ctx.Done()) - - // Interrupt doesn't interrupt this. - err = db.Exec(`SELECT 1`) - if err != nil { - t.Fatal(err) - } - - db.SetInterrupt(nil) - - stmt, _, err := db.Prepare(` - WITH RECURSIVE - fibonacci (curr, next) - AS ( - SELECT 0, 1 - UNION ALL - SELECT next, curr + next FROM fibonacci - LIMIT 1e6 - ) - SELECT min(curr) FROM fibonacci - `) - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - - cancel() - db.SetInterrupt(ctx.Done()) - - var serr *Error - - // Interrupting works. - err = stmt.Exec() - if err != nil { - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) - } - } - - // Interrupting sticks. - err = db.Exec(`SELECT 1`) - if err != nil { - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) - } - } - - db.SetInterrupt(nil) - - // Interrupting can be cleared. - err = db.Exec(`SELECT 1`) - if err != nil { - t.Fatal(err) - } -} - -func TestConn_Prepare_Empty(t *testing.T) { - t.Parallel() - - db, err := Open(":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - stmt, _, err := db.Prepare(``) - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - - if stmt != nil { - t.Error("want nil") - } -} - -func TestConn_Prepare_Tail(t *testing.T) { - t.Parallel() - - db, err := Open(":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`) - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - - if !strings.Contains(tail, "-- HERE") { - t.Errorf("got %q", tail) - } -} - -func TestConn_Prepare_Invalid(t *testing.T) { - t.Parallel() - - db, err := Open(":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - var serr *Error - - _, _, err = db.Prepare(`SELECT`) - if err == nil { - t.Fatal("want error") - } - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != ERROR { - t.Errorf("got %d, want sqlite3.ERROR", rc) - } - if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` { - t.Error("got message: ", got) - } - - _, _, err = db.Prepare(`SELECT * FRM sqlite_schema`) - if err == nil { - t.Fatal("want error") - } - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.ERROR", err) - } - if rc := serr.Code(); rc != ERROR { - t.Errorf("got %d, want sqlite3.ERROR", rc) - } - if got := serr.SQL(); got != `FRM sqlite_schema` { - t.Error("got SQL: ", got) - } - if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` { - t.Error("got message: ", got) - } -} - func TestConn_new(t *testing.T) { t.Parallel() diff --git a/const.go b/const.go index 9ba7553..62db4c8 100644 --- a/const.go +++ b/const.go @@ -197,6 +197,7 @@ const ( NULL Datatype = 5 ) +// String implements the [fmt.Stringer] interface. func (t Datatype) String() string { const name = "INTEGERFLOATTEXTBLOBNULL" switch t { diff --git a/const_test.go b/const_test.go index 5cf97bf..5753ebf 100644 --- a/const_test.go +++ b/const_test.go @@ -3,6 +3,8 @@ package sqlite3 import "testing" func TestDatatype_String(t *testing.T) { + t.Parallel() + tests := []struct { data Datatype want string diff --git a/driver/driver.go b/driver/driver.go index 3142caa..53cd18e 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -186,8 +186,9 @@ type stmt struct { var ( // Ensure these interfaces are implemented: - _ driver.StmtExecContext = stmt{} - _ driver.StmtQueryContext = stmt{} + _ driver.StmtExecContext = stmt{} + _ driver.StmtQueryContext = stmt{} + _ driver.NamedValueChecker = stmt{} ) func (s stmt) Close() error { @@ -256,6 +257,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive switch a := arg.Value.(type) { case bool: err = s.stmt.BindBool(id, a) + case int: + err = s.stmt.BindInt(id, a) case int64: err = s.stmt.BindInt64(id, a) case float64: @@ -264,6 +267,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive err = s.stmt.BindText(id, a) case []byte: err = s.stmt.BindBlob(id, a) + case sqlite3.ZeroBlob: + err = s.stmt.BindZeroBlob(id, int64(a)) case time.Time: err = s.stmt.BindText(id, a.Format(time.RFC3339Nano)) case nil: @@ -280,6 +285,16 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive return rows{ctx, s.stmt, s.conn}, nil } +func (s stmt) CheckNamedValue(arg *driver.NamedValue) error { + switch arg.Value.(type) { + case bool, int, int64, float64, string, []byte, + sqlite3.ZeroBlob, time.Time, nil: + return nil + default: + return driver.ErrSkip + } +} + type result struct{ lastInsertId, rowsAffected int64 } func (r result) LastInsertId() (int64, error) { diff --git a/driver/driver_test.go b/driver/driver_test.go index aa1e285..a5cfa8f 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -157,7 +157,7 @@ func Test_BeginTx(t *testing.T) { t.Fatal(err) } - _, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) + _, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) if err == nil { t.Error("want error") } @@ -310,3 +310,39 @@ func Test_QueryRow_blob_null(t *testing.T) { } } } + +func Test_ZeroBlob(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + _, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + _, err = conn.ExecContext(ctx, `INSERT INTO test(col) VALUES(?)`, sqlite3.ZeroBlob(4)) + if err != nil { + t.Fatal(err) + } + + var got []byte + err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if string(got) != "\x00\x00\x00\x00" { + t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got) + } +} diff --git a/stmt.go b/stmt.go index bfb534a..59997fe 100644 --- a/stmt.go +++ b/stmt.go @@ -222,6 +222,19 @@ func (s *Stmt) BindBlob(param int, value []byte) error { return s.c.error(r[0]) } +// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement. +// The leftmost SQL parameter has an index of 1. +// +// https://www.sqlite.org/c3ref/bind_blob.html +func (s *Stmt) BindZeroBlob(param int, n int64) error { + r, err := s.c.api.bindZeroBlob.Call(s.c.ctx, + uint64(s.handle), uint64(param), uint64(n)) + if err != nil { + panic(err) + } + return s.c.error(r[0]) +} + // BindNull binds a NULL to the prepared statement. // The leftmost SQL parameter has an index of 1. // diff --git a/driver/bradfitz_test.go b/tests/bradfitz/sql_test.go similarity index 74% rename from driver/bradfitz_test.go rename to tests/bradfitz/sql_test.go index 005fcc0..a264c90 100644 --- a/driver/bradfitz_test.go +++ b/tests/bradfitz/sql_test.go @@ -1,10 +1,13 @@ -package driver_test +package bradfitz + +// Adapted from: https://github.com/bradfitz/go-sql-test import ( "database/sql" "fmt" "math/rand" "path/filepath" + "sync" "testing" _ "github.com/ncruces/go-sqlite3/driver" @@ -37,11 +40,6 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result { return res } -// q converts "?" characters to $1, $2, $n on postgres, :1, :2, :n on Oracle -func (t params) q(sql string) string { - return sql -} - func (sqliteDB) RunTest(t *testing.T, fn func(params)) { db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db")) if err != nil { @@ -53,21 +51,17 @@ func (sqliteDB) RunTest(t *testing.T, fn func(params)) { } } -func sqlBlobParam(t params, size int) string { - return fmt.Sprintf("blob[%d]", size) -} - func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) } func testBlobs(t params) { var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar " + sqlBlobParam(t, 16) + ")") - t.mustExec(t.q("insert into "+TablePrefix+"foo (id, bar) values(?,?)"), 0, blob) + t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar blob)") + t.mustExec("insert into "+TablePrefix+"foo (id, bar) values(?,?)", 0, blob) want := fmt.Sprintf("%x", blob) b := make([]byte, 16) - err := t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&b) + err := t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&b) got := fmt.Sprintf("%x", b) if err != nil { t.Errorf("[]byte scan: %v", err) @@ -75,7 +69,7 @@ func testBlobs(t params) { t.Errorf("for []byte, got %q; want %q", got, want) } - err = t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&got) + err = t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&got) want = string(blob) if err != nil { t.Errorf("string scan: %v", err) @@ -88,14 +82,13 @@ func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) func testManyQueryRow(t params) { if testing.Short() { - t.Logf("skipping in short mode") - return + t.Skip("skipping in short mode") } t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))") - t.mustExec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob") + t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob") var name string for i := 0; i < 10000; i++ { - err := t.QueryRow(t.q("select name from "+TablePrefix+"foo where id = ?"), 1).Scan(&name) + err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name) if err != nil || name != "bob" { t.Fatalf("on query %d: err=%v, name=%q", i, err, name) } @@ -116,12 +109,12 @@ func testTxQuery(t params) { t.Logf("cannot drop table "+TablePrefix+"foo: %s", err) } - _, err = tx.Exec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob") + _, err = tx.Exec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob") if err != nil { t.Fatal(err) } - r, err := tx.Query(t.q("select name from "+TablePrefix+"foo where id = ?"), 1) + r, err := tx.Query("select name from "+TablePrefix+"foo where id = ?", 1) if err != nil { t.Fatal(err) } @@ -145,8 +138,7 @@ func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) func testPreparedStmt(t params) { if testing.Short() { - t.Logf("skipping in short mode") - return + t.Skip("skipping in short mode") } t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)") @@ -154,7 +146,7 @@ func testPreparedStmt(t params) { if err != nil { t.Fatalf("prepare 1: %v", err) } - ins, err := t.Prepare(t.q("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")) + ins, err := t.Prepare("INSERT INTO " + TablePrefix + "t (count) VALUES (?)") if err != nil { t.Fatalf("prepare 2: %v", err) } @@ -166,12 +158,11 @@ func testPreparedStmt(t params) { } const nRuns = 10 - ch := make(chan bool) + var wg sync.WaitGroup for i := 0; i < nRuns; i++ { + wg.Add(1) go func() { - defer func() { - ch <- true - }() + defer wg.Done() for j := 0; j < 10; j++ { count := 0 if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { @@ -185,7 +176,5 @@ func testPreparedStmt(t params) { } }() } - for i := 0; i < nRuns; i++ { - <-ch - } + wg.Wait() } diff --git a/tests/compile/empty/compile_test.go b/tests/compile/empty/compile_test.go index 9648c6a..38cc602 100644 --- a/tests/compile/empty/compile_test.go +++ b/tests/compile/empty/compile_test.go @@ -1,4 +1,4 @@ -package compile_empty +package compile import ( "testing" diff --git a/tests/compile/missing/compile_test.go b/tests/compile/missing/compile_test.go index be9cdb0..940741d 100644 --- a/tests/compile/missing/compile_test.go +++ b/tests/compile/missing/compile_test.go @@ -1,4 +1,4 @@ -package compile_empty +package compile import ( "testing" @@ -6,7 +6,7 @@ import ( "github.com/ncruces/go-sqlite3" ) -func TestCompile_empty(t *testing.T) { +func TestCompile_missing(t *testing.T) { sqlite3.Path = "sqlite3.wasm" _, err := sqlite3.Open(":memory:") if err == nil { diff --git a/tests/compile/nil/compile_test.go b/tests/compile/nil/compile_test.go index 5e3a8cc..9dc6275 100644 --- a/tests/compile/nil/compile_test.go +++ b/tests/compile/nil/compile_test.go @@ -1,4 +1,4 @@ -package compile_empty +package compile import ( "testing" @@ -6,7 +6,7 @@ import ( "github.com/ncruces/go-sqlite3" ) -func TestCompile_empty(t *testing.T) { +func TestCompile_nil(t *testing.T) { _, err := sqlite3.Open(":memory:") if err == nil { t.Error("want error") diff --git a/tests/conn_test.go b/tests/conn_test.go new file mode 100644 index 0000000..fa80daf --- /dev/null +++ b/tests/conn_test.go @@ -0,0 +1,229 @@ +package tests + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func TestConn_Open_dir(t *testing.T) { + t.Parallel() + + _, err := sqlite3.Open(".") + if err == nil { + t.Fatal("want error") + } + var serr *sqlite3.Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.CANTOPEN { + t.Errorf("got %d, want sqlite3.CANTOPEN", rc) + } + if got := err.Error(); got != `sqlite3: unable to open database file` { + t.Error("got message: ", got) + } +} + +func TestConn_Close(t *testing.T) { + var conn *sqlite3.Conn + conn.Close() +} + +func TestConn_Close_BUSY(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare(`BEGIN`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + err = db.Close() + if err == nil { + t.Fatal("want error") + } + var serr *sqlite3.Error + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.BUSY { + t.Errorf("got %d, want sqlite3.BUSY", rc) + } + var terr interface{ Temporary() bool } + if !errors.As(err, &terr) || !terr.Temporary() { + t.Error("not temporary", err) + } + if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` { + t.Error("got message: ", got) + } +} + +func TestConn_SetInterrupt(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + db.SetInterrupt(ctx.Done()) + + // Interrupt doesn't interrupt this. + err = db.Exec(`SELECT 1`) + if err != nil { + t.Fatal(err) + } + + db.SetInterrupt(nil) + + stmt, _, err := db.Prepare(` + WITH RECURSIVE + fibonacci (curr, next) + AS ( + SELECT 0, 1 + UNION ALL + SELECT next, curr + next FROM fibonacci + LIMIT 1e6 + ) + SELECT min(curr) FROM fibonacci + `) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + cancel() + db.SetInterrupt(ctx.Done()) + + var serr *sqlite3.Error + + // Interrupting works. + err = stmt.Exec() + if err != nil { + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) + } + } + + // Interrupting sticks. + err = db.Exec(`SELECT 1`) + if err != nil { + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) + } + } + + db.SetInterrupt(nil) + + // Interrupting can be cleared. + err = db.Exec(`SELECT 1`) + if err != nil { + t.Fatal(err) + } +} + +func TestConn_Prepare_empty(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare(``) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt != nil { + t.Error("want nil") + } +} + +func TestConn_Prepare_tail(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if !strings.Contains(tail, "-- HERE") { + t.Errorf("got %q", tail) + } +} + +func TestConn_Prepare_invalid(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var serr *sqlite3.Error + + _, _, err = db.Prepare(`SELECT`) + if err == nil { + t.Fatal("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` { + t.Error("got message: ", got) + } + + _, _, err = db.Prepare(`SELECT * FRM sqlite_schema`) + if err == nil { + t.Fatal("want error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.ERROR", err) + } + if rc := serr.Code(); rc != sqlite3.ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := serr.SQL(); got != `FRM sqlite_schema` { + t.Error("got SQL: ", got) + } + if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` { + t.Error("got message: ", got) + } +} diff --git a/tests/db_test.go b/tests/db_test.go index dc1229a..47d8988 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -17,6 +17,8 @@ func TestDB_file(t *testing.T) { } func testDB(t *testing.T, name string) { + t.Parallel() + db, err := sqlite3.Open(name) if err != nil { t.Fatal(err) diff --git a/tests/dir_test.go b/tests/dir_test.go deleted file mode 100644 index 928961b..0000000 --- a/tests/dir_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package tests - -import ( - "errors" - "testing" - - "github.com/ncruces/go-sqlite3" - _ "github.com/ncruces/go-sqlite3/embed" -) - -func TestDir(t *testing.T) { - _, err := sqlite3.Open(".") - if err == nil { - t.Fatal("want error") - } - var serr *sqlite3.Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.CANTOPEN { - t.Errorf("got %d, want sqlite3.CANTOPEN", rc) - } - if got := err.Error(); got != `sqlite3: unable to open database file` { - t.Error("got message: ", got) - } -} diff --git a/tests/driver_test.go b/tests/driver_test.go index 7ea00a6..ebf851d 100644 --- a/tests/driver_test.go +++ b/tests/driver_test.go @@ -10,6 +10,8 @@ import ( ) func TestDriver(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/parallel_test.go b/tests/parallel/parallel_test.go similarity index 99% rename from tests/parallel_test.go rename to tests/parallel/parallel_test.go index a251e92..e9b6cbe 100644 --- a/tests/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -22,7 +22,7 @@ func TestParallel(t *testing.T) { func TestMultiProcess(t *testing.T) { if testing.Short() { - return + t.Skip() } name := filepath.Join(t.TempDir(), "test.db") @@ -57,7 +57,7 @@ func TestMultiProcess(t *testing.T) { func TestChildProcess(t *testing.T) { name := os.Getenv("TestMultiProcess_dbname") if name == "" || testing.Short() { - return + t.SkipNow() } testParallel(t, name, 1000) diff --git a/stmt_test.go b/tests/stmt_test.go similarity index 69% rename from stmt_test.go rename to tests/stmt_test.go index 45c8946..fa53314 100644 --- a/stmt_test.go +++ b/tests/stmt_test.go @@ -1,15 +1,17 @@ -package sqlite3 +package tests import ( "math" "testing" "time" + + "github.com/ncruces/go-sqlite3" ) func TestStmt(t *testing.T) { t.Parallel() - db, err := Open(":memory:") + db, err := sqlite3.Open(":memory:") if err != nil { t.Fatal(err) } @@ -30,103 +32,80 @@ func TestStmt(t *testing.T) { t.Errorf("got %d, want 1", got) } - err = stmt.BindBool(1, false) - if err != nil { + if err := stmt.BindBool(1, false); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.Exec() - if err != nil { + if err := stmt.BindBool(1, true); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.ClearBindings() - if err != nil { + if err := stmt.BindInt(1, 2); err != nil { + t.Fatal(err) + } + if err = stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.Exec() - if err != nil { + if err := stmt.BindFloat(1, math.Pi); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.BindBool(1, true) - if err != nil { + if err := stmt.BindNull(1); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.Exec() - if err != nil { + if err := stmt.BindText(1, ""); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.BindInt(1, 2) - if err != nil { + if err := stmt.BindText(1, "text"); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.Exec() - if err != nil { + if err := stmt.BindBlob(1, []byte("blob")); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.BindFloat(1, math.Pi) - if err != nil { + if err := stmt.BindBlob(1, nil); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.Exec() - if err != nil { + if err := stmt.BindZeroBlob(1, 4); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { t.Fatal(err) } - err = stmt.BindNull(1) - if err != nil { + if err := stmt.ClearBindings(); err != nil { t.Fatal(err) } - - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.BindText(1, "") - if err != nil { - t.Fatal(err) - } - - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.BindText(1, "text") - if err != nil { - t.Fatal(err) - } - - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.BindBlob(1, []byte("blob")) - if err != nil { - t.Fatal(err) - } - - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.BindBlob(1, nil) - if err != nil { - t.Fatal(err) - } - - err = stmt.Exec() - if err != nil { + if err := stmt.Exec(); err != nil { t.Fatal(err) } @@ -135,7 +114,7 @@ func TestStmt(t *testing.T) { t.Fatal(err) } - // The table should have: 0, NULL, 1, 2, π, NULL, "", "text", `blob`, NULL + // The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL stmt, _, err = db.Prepare(`SELECT col FROM test`) if err != nil { t.Fatal(err) @@ -143,7 +122,7 @@ func TestStmt(t *testing.T) { defer stmt.Close() if stmt.Step() { - if got := stmt.ColumnType(0); got != INTEGER { + if got := stmt.ColumnType(0); got != sqlite3.INTEGER { t.Errorf("got %v, want INTEGER", got) } if got := stmt.ColumnBool(0); got != false { @@ -164,28 +143,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != NULL { - t.Errorf("got %v, want NULL", got) - } - if got := stmt.ColumnBool(0); got != false { - t.Errorf("got %v, want false", got) - } - if got := stmt.ColumnInt(0); got != 0 { - t.Errorf("got %v, want zero", got) - } - if got := stmt.ColumnFloat(0); got != 0 { - t.Errorf("got %v, want zero", got) - } - if got := stmt.ColumnText(0); got != "" { - t.Errorf("got %q, want empty", got) - } - if got := stmt.ColumnBlob(0, nil); got != nil { - t.Errorf("got %q, want nil", got) - } - } - - if stmt.Step() { - if got := stmt.ColumnType(0); got != INTEGER { + if got := stmt.ColumnType(0); got != sqlite3.INTEGER { t.Errorf("got %v, want INTEGER", got) } if got := stmt.ColumnBool(0); got != true { @@ -206,7 +164,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != INTEGER { + if got := stmt.ColumnType(0); got != sqlite3.INTEGER { t.Errorf("got %v, want INTEGER", got) } if got := stmt.ColumnBool(0); got != true { @@ -227,7 +185,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != FLOAT { + if got := stmt.ColumnType(0); got != sqlite3.FLOAT { t.Errorf("got %v, want FLOAT", got) } if got := stmt.ColumnBool(0); got != true { @@ -248,7 +206,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != NULL { + if got := stmt.ColumnType(0); got != sqlite3.NULL { t.Errorf("got %v, want NULL", got) } if got := stmt.ColumnBool(0); got != false { @@ -269,7 +227,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != TEXT { + if got := stmt.ColumnType(0); got != sqlite3.TEXT { t.Errorf("got %v, want TEXT", got) } if got := stmt.ColumnBool(0); got != false { @@ -290,7 +248,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != TEXT { + if got := stmt.ColumnType(0); got != sqlite3.TEXT { t.Errorf("got %v, want TEXT", got) } if got := stmt.ColumnBool(0); got != false { @@ -311,7 +269,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != BLOB { + if got := stmt.ColumnType(0); got != sqlite3.BLOB { t.Errorf("got %v, want BLOB", got) } if got := stmt.ColumnBool(0); got != false { @@ -332,7 +290,7 @@ func TestStmt(t *testing.T) { } if stmt.Step() { - if got := stmt.ColumnType(0); got != NULL { + if got := stmt.ColumnType(0); got != sqlite3.NULL { t.Errorf("got %v, want NULL", got) } if got := stmt.ColumnBool(0); got != false { @@ -352,24 +310,66 @@ func TestStmt(t *testing.T) { } } - err = stmt.Close() - if err != nil { + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.BLOB { + t.Errorf("got %v, want BLOB", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "\x00\x00\x00\x00" { + t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" { + t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.NULL { + t.Errorf("got %v, want NULL", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "" { + t.Errorf("got %q, want empty", got) + } + if got := stmt.ColumnBlob(0, nil); got != nil { + t.Errorf("got %q, want nil", got) + } + } + + if err := stmt.Close(); err != nil { t.Fatal(err) } - err = db.Close() - if err != nil { + if err := db.Close(); err != nil { t.Fatal(err) } } func TestStmt_Close(t *testing.T) { - var stmt *Stmt + var stmt *sqlite3.Stmt stmt.Close() } func TestStmt_BindName(t *testing.T) { - db, err := Open(":memory:") + t.Parallel() + + db, err := sqlite3.Open(":memory:") if err != nil { t.Fatal(err) } @@ -401,7 +401,9 @@ func TestStmt_BindName(t *testing.T) { } func TestStmt_Time(t *testing.T) { - db, err := Open(":memory:") + t.Parallel() + + db, err := sqlite3.Open(":memory:") if err != nil { t.Fatal(err) } @@ -414,44 +416,44 @@ func TestStmt_Time(t *testing.T) { defer stmt.Close() reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) - err = stmt.BindTime(1, reference, TimeFormat4) + err = stmt.BindTime(1, reference, sqlite3.TimeFormat4) if err != nil { t.Fatal(err) } - err = stmt.BindTime(2, reference, TimeFormatUnixMilli) + err = stmt.BindTime(2, reference, sqlite3.TimeFormatUnixMilli) if err != nil { t.Fatal(err) } - err = stmt.BindTime(3, reference, TimeFormatJulianDay) + err = stmt.BindTime(3, reference, sqlite3.TimeFormatJulianDay) if err != nil { t.Fatal(err) } if now := time.Now(); stmt.Step() { - if got := stmt.ColumnTime(0, TimeFormatAuto); !reference.Equal(got) { + if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !reference.Equal(got) { t.Errorf("got %v, want %v", got, reference) } - if got := stmt.ColumnTime(1, TimeFormatAuto); !reference.Equal(got) { + if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !reference.Equal(got) { t.Errorf("got %v, want %v", got, reference) } - if got := stmt.ColumnTime(2, TimeFormatAuto); reference.Sub(got) > time.Millisecond { + if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); reference.Sub(got) > time.Millisecond { t.Errorf("got %v, want %v", got, reference) } - if got := stmt.ColumnTime(3, TimeFormatAuto); now.Sub(got) > time.Second { + if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second { t.Errorf("got %v, want %v", got, now) } - if got := stmt.ColumnTime(4, TimeFormatAuto); now.Sub(got) > time.Second { + if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second { t.Errorf("got %v, want %v", got, now) } - if got := stmt.ColumnTime(5, TimeFormatAuto); now.Sub(got) > time.Millisecond { + if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); now.Sub(got) > time.Millisecond { t.Errorf("got %v, want %v", got, now) } - if got := stmt.ColumnTime(6, TimeFormatAuto); got != (time.Time{}) { + if got := stmt.ColumnTime(6, sqlite3.TimeFormatAuto); got != (time.Time{}) { t.Errorf("got %v, want zero", got) } - if got := stmt.ColumnTime(7, TimeFormatAuto); got != (time.Time{}) { + if got := stmt.ColumnTime(7, sqlite3.TimeFormatAuto); got != (time.Time{}) { t.Errorf("got %v, want zero", got) } if stmt.Err() == nil { diff --git a/time.go b/time.go index 8f12a9e..4e88d93 100644 --- a/time.go +++ b/time.go @@ -14,6 +14,9 @@ import ( // https://www.sqlite.org/lang_datefunc.html type TimeFormat string +// TimeFormats recognized by SQLite to encode/decode time values. +// +// https://www.sqlite.org/lang_datefunc.html const ( TimeFormatDefault TimeFormat = "" // time.RFC3339Nano @@ -43,9 +46,9 @@ const ( TimeFormatJulianDay TimeFormat = "julianday" TimeFormatUnix TimeFormat = "unixepoch" TimeFormatUnixFrac TimeFormat = "unixepoch_frac" - TimeFormatUnixMilli TimeFormat = "unixepoch_milli" - TimeFormatUnixMicro TimeFormat = "unixepoch_micro" - TimeFormatUnixNano TimeFormat = "unixepoch_nano" + TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format + TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format + TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format // Auto TimeFormatAuto TimeFormat = "auto" @@ -54,7 +57,10 @@ const ( // Encode encodes a time value using this format. // // [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano], -// preserving timezone, with nanosecond accuracy. +// with nanosecond accuracy, and preserving timezone. +// +// Formats that don't record the timezone +// convert time values to UTC before encoding. // // https://www.sqlite.org/lang_datefunc.html func (f TimeFormat) Encode(t time.Time) any { @@ -65,7 +71,7 @@ func (f TimeFormat) Encode(t time.Time) any { case TimeFormatUnix: return t.Unix() case TimeFormatUnixFrac: - return float64(t.Unix()) + float64(t.Nanosecond())/1_000_000_000 + return float64(t.Unix()) + float64(t.Nanosecond())*1e-9 case TimeFormatUnixMilli: return t.UnixMilli() case TimeFormatUnixMicro: @@ -77,7 +83,9 @@ func (f TimeFormat) Encode(t time.Time) any { f = time.RFC3339Nano } // SQLite assumes UTC if unspecified. - if !strings.Contains(string(f), "Z07") && !strings.Contains(string(f), "-07") { + if !strings.Contains(string(f), "MST") && + !strings.Contains(string(f), "Z07") && + !strings.Contains(string(f), "-07") { t = t.UTC() } return t.Format(string(f)) @@ -85,6 +93,9 @@ func (f TimeFormat) Encode(t time.Time) any { // Decode decodes a time value using this format. // +// Decoding of SQLite recognized formats is lenient: +// timezones and fractional seconds are always optional. +// // https://www.sqlite.org/lang_datefunc.html func (f TimeFormat) Decode(v any) (time.Time, error) { switch f { @@ -112,7 +123,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { switch v := v.(type) { case float64: sec, frac := math.Modf(v) - nsec := math.Floor(frac * 1_000_000_000) + nsec := math.Floor(frac * 1e9) return time.Unix(int64(sec), int64(nsec)), nil case int64: return time.Unix(v, 0), nil @@ -130,7 +141,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { } switch v := v.(type) { case float64: - return time.UnixMilli(int64(v)), nil + return time.UnixMilli(int64(math.Floor(v))), nil case int64: return time.UnixMilli(int64(v)), nil default: @@ -147,7 +158,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { } switch v := v.(type) { case float64: - return time.UnixMicro(int64(v)), nil + return time.UnixMicro(int64(math.Floor(v))), nil case int64: return time.UnixMicro(int64(v)), nil default: @@ -164,7 +175,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { } switch v := v.(type) { case float64: - return time.Unix(0, int64(v)), nil + return time.Unix(0, int64(math.Floor(v))), nil case int64: return time.Unix(0, int64(v)), nil default: diff --git a/time_test.go b/time_test.go index af19cc5..5893a4c 100644 --- a/time_test.go +++ b/time_test.go @@ -7,6 +7,8 @@ import ( ) func TestTimeFormat_Encode(t *testing.T) { + t.Parallel() + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) tests := []struct { @@ -33,6 +35,8 @@ func TestTimeFormat_Encode(t *testing.T) { } func TestTimeFormat_Decode(t *testing.T) { + t.Parallel() + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) diff --git a/util_test.go b/util_test.go index 391e06f..04e6259 100644 --- a/util_test.go +++ b/util_test.go @@ -5,6 +5,8 @@ import ( ) func Test_emptyStatement(t *testing.T) { + t.Parallel() + tests := []struct { name string stmt string