Zero blobs, tests, documentation

This commit is contained in:
Nuno Cruces
2023-02-22 14:19:56 +00:00
parent b749b32a62
commit e91758c6a4
22 changed files with 487 additions and 395 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1 @@
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6

View File

@@ -18,4 +18,9 @@ Roadmap:
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on Windows/Unix
- [ ] shared memory, compatible with SQLite on Windows/Unix
- needed for improved WAL mode
- needed for improved WAL mode
- [ ] advanced features
- [ ] incremental BLOB I/O
- [ ] online backup
- [ ] session extension
- [ ] snapshot

6
blob.go Normal file
View File

@@ -0,0 +1,6 @@
package sqlite3
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database.sql.DB.Exec] and similar methods.
type ZeroBlob int64

View File

@@ -2,212 +2,10 @@ package sqlite3
import (
"bytes"
"context"
"errors"
"math"
"strings"
"testing"
)
func TestConn_Close(t *testing.T) {
var conn *Conn
conn.Close()
}
func TestConn_Close_BUSY(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
var serr *Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
db.SetInterrupt(nil)
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
cancel()
db.SetInterrupt(ctx.Done())
var serr *Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
db.SetInterrupt(nil)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_Empty(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_Tail(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if !strings.Contains(tail, "-- HERE") {
t.Errorf("got %q", tail)
}
}
func TestConn_Prepare_Invalid(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *Error
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
}
}
func TestConn_new(t *testing.T) {
t.Parallel()

View File

@@ -197,6 +197,7 @@ const (
NULL Datatype = 5
)
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
switch t {

View File

@@ -3,6 +3,8 @@ package sqlite3
import "testing"
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
want string

View File

@@ -186,8 +186,9 @@ type stmt struct {
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.NamedValueChecker = stmt{}
)
func (s stmt) Close() error {
@@ -256,6 +257,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
case int:
err = s.stmt.BindInt(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
case float64:
@@ -264,6 +267,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
err = s.stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
case nil:
@@ -280,6 +285,16 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
return rows{ctx, s.stmt, s.conn}, nil
}
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
return nil
default:
return driver.ErrSkip
}
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {

View File

@@ -157,7 +157,7 @@ func Test_BeginTx(t *testing.T) {
t.Fatal(err)
}
_, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
_, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err == nil {
t.Error("want error")
}
@@ -310,3 +310,39 @@ func Test_QueryRow_blob_null(t *testing.T) {
}
}
}
func Test_ZeroBlob(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
_, err = conn.ExecContext(ctx, `INSERT INTO test(col) VALUES(?)`, sqlite3.ZeroBlob(4))
if err != nil {
t.Fatal(err)
}
var got []byte
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}

13
stmt.go
View File

@@ -222,6 +222,19 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
return s.c.error(r[0])
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r, err := s.c.api.bindZeroBlob.Call(s.c.ctx,
uint64(s.handle), uint64(param), uint64(n))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
// BindNull binds a NULL to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//

View File

@@ -1,10 +1,13 @@
package driver_test
package bradfitz
// Adapted from: https://github.com/bradfitz/go-sql-test
import (
"database/sql"
"fmt"
"math/rand"
"path/filepath"
"sync"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
@@ -37,11 +40,6 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
return res
}
// q converts "?" characters to $1, $2, $n on postgres, :1, :2, :n on Oracle
func (t params) q(sql string) string {
return sql
}
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db"))
if err != nil {
@@ -53,21 +51,17 @@ func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
}
}
func sqlBlobParam(t params, size int) string {
return fmt.Sprintf("blob[%d]", size)
}
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
func testBlobs(t params) {
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar " + sqlBlobParam(t, 16) + ")")
t.mustExec(t.q("insert into "+TablePrefix+"foo (id, bar) values(?,?)"), 0, blob)
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar blob)")
t.mustExec("insert into "+TablePrefix+"foo (id, bar) values(?,?)", 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&b)
err := t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
@@ -75,7 +69,7 @@ func testBlobs(t params) {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = t.QueryRow(t.q("select bar from "+TablePrefix+"foo where id = ?"), 0).Scan(&got)
err = t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
@@ -88,14 +82,13 @@ func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow)
func testManyQueryRow(t params) {
if testing.Short() {
t.Logf("skipping in short mode")
return
t.Skip("skipping in short mode")
}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
t.mustExec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob")
t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := t.QueryRow(t.q("select name from "+TablePrefix+"foo where id = ?"), 1).Scan(&name)
err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
@@ -116,12 +109,12 @@ func testTxQuery(t params) {
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
}
_, err = tx.Exec(t.q("insert into "+TablePrefix+"foo (id, name) values(?,?)"), 1, "bob")
_, err = tx.Exec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(t.q("select name from "+TablePrefix+"foo where id = ?"), 1)
r, err := tx.Query("select name from "+TablePrefix+"foo where id = ?", 1)
if err != nil {
t.Fatal(err)
}
@@ -145,8 +138,7 @@ func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt)
func testPreparedStmt(t params) {
if testing.Short() {
t.Logf("skipping in short mode")
return
t.Skip("skipping in short mode")
}
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
@@ -154,7 +146,7 @@ func testPreparedStmt(t params) {
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := t.Prepare(t.q("INSERT INTO " + TablePrefix + "t (count) VALUES (?)"))
ins, err := t.Prepare("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
@@ -166,12 +158,11 @@ func testPreparedStmt(t params) {
}
const nRuns = 10
ch := make(chan bool)
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer func() {
ch <- true
}()
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
@@ -185,7 +176,5 @@ func testPreparedStmt(t params) {
}
}()
}
for i := 0; i < nRuns; i++ {
<-ch
}
wg.Wait()
}

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/ncruces/go-sqlite3"
)
func TestCompile_empty(t *testing.T) {
func TestCompile_missing(t *testing.T) {
sqlite3.Path = "sqlite3.wasm"
_, err := sqlite3.Open(":memory:")
if err == nil {

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/ncruces/go-sqlite3"
)
func TestCompile_empty(t *testing.T) {
func TestCompile_nil(t *testing.T) {
_, err := sqlite3.Open(":memory:")
if err == nil {
t.Error("want error")

229
tests/conn_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tests
import (
"context"
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
}
}
func TestConn_Close(t *testing.T) {
var conn *sqlite3.Conn
conn.Close()
}
func TestConn_Close_BUSY(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
db.SetInterrupt(nil)
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
cancel()
db.SetInterrupt(ctx.Done())
var serr *sqlite3.Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
db.SetInterrupt(nil)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if !strings.Contains(tail, "-- HERE") {
t.Errorf("got %q", tail)
}
}
func TestConn_Prepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *sqlite3.Error
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
}
}

View File

@@ -17,6 +17,8 @@ func TestDB_file(t *testing.T) {
}
func testDB(t *testing.T, name string) {
t.Parallel()
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)

View File

@@ -1,26 +0,0 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDir(t *testing.T) {
_, err := sqlite3.Open(".")
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
}
}

View File

@@ -10,6 +10,8 @@ import (
)
func TestDriver(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

View File

@@ -22,7 +22,7 @@ func TestParallel(t *testing.T) {
func TestMultiProcess(t *testing.T) {
if testing.Short() {
return
t.Skip()
}
name := filepath.Join(t.TempDir(), "test.db")
@@ -57,7 +57,7 @@ func TestMultiProcess(t *testing.T) {
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
return
t.SkipNow()
}
testParallel(t, name, 1000)

View File

@@ -1,15 +1,17 @@
package sqlite3
package tests
import (
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestStmt(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
@@ -30,103 +32,80 @@ func TestStmt(t *testing.T) {
t.Errorf("got %d, want 1", got)
}
err = stmt.BindBool(1, false)
if err != nil {
if err := stmt.BindBool(1, false); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindBool(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.ClearBindings()
if err != nil {
if err := stmt.BindInt(1, 2); err != nil {
t.Fatal(err)
}
if err = stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindFloat(1, math.Pi); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindBool(1, true)
if err != nil {
if err := stmt.BindNull(1); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindText(1, ""); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindInt(1, 2)
if err != nil {
if err := stmt.BindText(1, "text"); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindFloat(1, math.Pi)
if err != nil {
if err := stmt.BindBlob(1, nil); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindZeroBlob(1, 4); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindNull(1)
if err != nil {
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "text")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, []byte("blob"))
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, nil)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
@@ -135,7 +114,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
// The table should have: 0, NULL, 1, 2, π, NULL, "", "text", `blob`, NULL
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
@@ -143,7 +122,7 @@ func TestStmt(t *testing.T) {
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -164,28 +143,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
@@ -206,7 +164,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
@@ -227,7 +185,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != FLOAT {
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnBool(0); got != true {
@@ -248,7 +206,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -269,7 +227,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -290,7 +248,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -311,7 +269,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != BLOB {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -332,7 +290,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -352,24 +310,66 @@ func TestStmt(t *testing.T) {
}
}
err = stmt.Close()
if err != nil {
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if err := stmt.Close(); err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestStmt_Close(t *testing.T) {
var stmt *Stmt
var stmt *sqlite3.Stmt
stmt.Close()
}
func TestStmt_BindName(t *testing.T) {
db, err := Open(":memory:")
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
@@ -401,7 +401,9 @@ func TestStmt_BindName(t *testing.T) {
}
func TestStmt_Time(t *testing.T) {
db, err := Open(":memory:")
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
@@ -414,44 +416,44 @@ func TestStmt_Time(t *testing.T) {
defer stmt.Close()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
err = stmt.BindTime(1, reference, TimeFormat4)
err = stmt.BindTime(1, reference, sqlite3.TimeFormat4)
if err != nil {
t.Fatal(err)
}
err = stmt.BindTime(2, reference, TimeFormatUnixMilli)
err = stmt.BindTime(2, reference, sqlite3.TimeFormatUnixMilli)
if err != nil {
t.Fatal(err)
}
err = stmt.BindTime(3, reference, TimeFormatJulianDay)
err = stmt.BindTime(3, reference, sqlite3.TimeFormatJulianDay)
if err != nil {
t.Fatal(err)
}
if now := time.Now(); stmt.Step() {
if got := stmt.ColumnTime(0, TimeFormatAuto); !reference.Equal(got) {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !reference.Equal(got) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(1, TimeFormatAuto); !reference.Equal(got) {
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !reference.Equal(got) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(2, TimeFormatAuto); reference.Sub(got) > time.Millisecond {
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); reference.Sub(got) > time.Millisecond {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(3, TimeFormatAuto); now.Sub(got) > time.Second {
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(4, TimeFormatAuto); now.Sub(got) > time.Second {
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(5, TimeFormatAuto); now.Sub(got) > time.Millisecond {
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); now.Sub(got) > time.Millisecond {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(6, TimeFormatAuto); got != (time.Time{}) {
if got := stmt.ColumnTime(6, sqlite3.TimeFormatAuto); got != (time.Time{}) {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnTime(7, TimeFormatAuto); got != (time.Time{}) {
if got := stmt.ColumnTime(7, sqlite3.TimeFormatAuto); got != (time.Time{}) {
t.Errorf("got %v, want zero", got)
}
if stmt.Err() == nil {

31
time.go
View File

@@ -14,6 +14,9 @@ import (
// https://www.sqlite.org/lang_datefunc.html
type TimeFormat string
// TimeFormats recognized by SQLite to encode/decode time values.
//
// https://www.sqlite.org/lang_datefunc.html
const (
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano
@@ -43,9 +46,9 @@ const (
TimeFormatJulianDay TimeFormat = "julianday"
TimeFormatUnix TimeFormat = "unixepoch"
TimeFormatUnixFrac TimeFormat = "unixepoch_frac"
TimeFormatUnixMilli TimeFormat = "unixepoch_milli"
TimeFormatUnixMicro TimeFormat = "unixepoch_micro"
TimeFormatUnixNano TimeFormat = "unixepoch_nano"
TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format
TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format
TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format
// Auto
TimeFormatAuto TimeFormat = "auto"
@@ -54,7 +57,10 @@ const (
// Encode encodes a time value using this format.
//
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// preserving timezone, with nanosecond accuracy.
// with nanosecond accuracy, and preserving timezone.
//
// Formats that don't record the timezone
// convert time values to UTC before encoding.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Encode(t time.Time) any {
@@ -65,7 +71,7 @@ func (f TimeFormat) Encode(t time.Time) any {
case TimeFormatUnix:
return t.Unix()
case TimeFormatUnixFrac:
return float64(t.Unix()) + float64(t.Nanosecond())/1_000_000_000
return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
case TimeFormatUnixMilli:
return t.UnixMilli()
case TimeFormatUnixMicro:
@@ -77,7 +83,9 @@ func (f TimeFormat) Encode(t time.Time) any {
f = time.RFC3339Nano
}
// SQLite assumes UTC if unspecified.
if !strings.Contains(string(f), "Z07") && !strings.Contains(string(f), "-07") {
if !strings.Contains(string(f), "MST") &&
!strings.Contains(string(f), "Z07") &&
!strings.Contains(string(f), "-07") {
t = t.UTC()
}
return t.Format(string(f))
@@ -85,6 +93,9 @@ func (f TimeFormat) Encode(t time.Time) any {
// Decode decodes a time value using this format.
//
// Decoding of SQLite recognized formats is lenient:
// timezones and fractional seconds are always optional.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) {
switch f {
@@ -112,7 +123,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
switch v := v.(type) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1_000_000_000)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)), nil
case int64:
return time.Unix(v, 0), nil
@@ -130,7 +141,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(v)), nil
return time.UnixMilli(int64(math.Floor(v))), nil
case int64:
return time.UnixMilli(int64(v)), nil
default:
@@ -147,7 +158,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(v)), nil
return time.UnixMicro(int64(math.Floor(v))), nil
case int64:
return time.UnixMicro(int64(v)), nil
default:
@@ -164,7 +175,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(v)), nil
return time.Unix(0, int64(math.Floor(v))), nil
case int64:
return time.Unix(0, int64(v)), nil
default:

View File

@@ -7,6 +7,8 @@ import (
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
@@ -33,6 +35,8 @@ func TestTimeFormat_Encode(t *testing.T) {
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))

View File

@@ -5,6 +5,8 @@ import (
)
func Test_emptyStatement(t *testing.T) {
t.Parallel()
tests := []struct {
name string
stmt string