Compare commits

...

10 Commits

Author SHA1 Message Date
Nuno Cruces
71ae26e5c9 Documentation. 2023-02-22 17:51:30 +00:00
Nuno Cruces
e91758c6a4 Zero blobs, tests, documentation 2023-02-22 14:19:56 +00:00
Nuno Cruces
b749b32a62 Unlock tweaks, tests. 2023-02-21 12:56:39 +00:00
Nuno Cruces
3b4df71a94 Time handling. 2023-02-21 04:45:25 +00:00
Nuno Cruces
df687a1c54 Tests. 2023-02-20 14:43:19 +00:00
Edoardo Vacchi
2f5b9837e1 deps: updates wazero to 1.0.0-pre.9
This updates [wazero](https://wazero.io/) to [1.0.0-pre.9][1]. Notably:

* This release includes our last breaking changes before 1.0.0 final:
  * Requires at least Go 1.8
  * Renames `Runtime.InstantiateModuleFromBinary` to `Runtime.Instantiate`
* This release also integrates Go context to limit execution time.
  More details on the [Release Notes][1]
* We are now passing third-party integration test suites: wasi-testsuite,
  TinyGo's, Zig's.

[1]: https://github.com/tetratelabs/wazero/releases/tag/v1.0.0-pre.9

Signed-off-by: Edoardo Vacchi <evacchi@users.noreply.github.com>
2023-02-20 13:32:52 +00:00
Nuno Cruces
c351400be7 Tests. 2023-02-20 13:30:01 +00:00
Nuno Cruces
231d3a0438 Read-only transactions, locking. 2023-02-19 16:16:13 +00:00
Nuno Cruces
2f25e4eedb Bug fixes, optimizations, fuzz testing. 2023-02-19 12:44:26 +00:00
Nuno Cruces
ad27d5d840 Support pragmas, integration test. 2023-02-18 13:15:01 +00:00
38 changed files with 2067 additions and 544 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1 @@
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6

View File

@@ -36,5 +36,5 @@ jobs:
uses: ncruces/go-coverage-report@main
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'
github.event_name == 'push'
continue-on-error: true

View File

@@ -6,8 +6,7 @@
⚠️ CAUTION ⚠️
This is a WIP.\
DO NOT USE with data you care about.
This is a WIP.
Roadmap:
- [x] build SQLite using `zig cc --target=wasm32-wasi`
@@ -18,4 +17,11 @@ Roadmap:
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on Windows/Unix
- [ ] shared memory, compatible with SQLite on Windows/Unix
- needed for improved WAL mode
- needed for improved WAL mode
- [ ] advanced SQLite features
- [ ] nested transactions
- [ ] incremental BLOB I/O
- [ ] online backup
- [ ] session extension
- [ ] snapshots
- [ ] SQL functions

6
blob.go Normal file
View File

@@ -0,0 +1,6 @@
package sqlite3
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64

View File

@@ -49,6 +49,10 @@ func (s *sqlite3Runtime) compileModule(ctx context.Context) {
return
}
}
if bin == nil {
s.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

View File

@@ -68,6 +68,8 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
// open blob handles, and/or unfinished backup objects,
// Close will leave the database connection open and return [BUSY].
//
// It is safe to close a nil, zero or closed connection.
//
// https://www.sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
if c == nil || c.handle == 0 {
@@ -179,6 +181,10 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
if emptyStatement(sql) {
return nil, "", nil
}
defer c.arena.reset()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)

View File

@@ -2,187 +2,10 @@ package sqlite3
import (
"bytes"
"context"
"errors"
"math"
"testing"
)
func TestConn_Close(t *testing.T) {
var conn *Conn
conn.Close()
}
func TestConn_Close_BUSY(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
var serr *Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.TODO())
db.SetInterrupt(ctx.Done())
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
db.SetInterrupt(nil)
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
cancel()
db.SetInterrupt(ctx.Done())
var serr *Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
db.SetInterrupt(nil)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_Empty(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_Invalid(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *Error
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
}
}
func TestConn_new(t *testing.T) {
t.Parallel()

View File

@@ -197,6 +197,7 @@ const (
NULL Datatype = 5
)
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
switch t {

View File

@@ -3,6 +3,8 @@ package sqlite3
import "testing"
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
want string

View File

@@ -5,7 +5,10 @@ import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"net/url"
"strings"
"time"
"github.com/ncruces/go-sqlite3"
@@ -22,26 +25,56 @@ func (sqlite) Open(name string) (driver.Conn, error) {
if err != nil {
return nil, err
}
// If the database is not in WAL mode,
// use normal locking mode.
journal, err := pragma(c, "journal_mode")
var txBegin string
var pragmas strings.Builder
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
}
}
if pragmas.Len() == 0 {
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
}
err = c.Exec(pragmas.String())
if err != nil {
return nil, err
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
if journal != "wal" {
pragma(c, "locking_mode=normal")
}
return conn{c}, nil
return conn{
conn: c,
txBegin: txBegin,
pragmas: pragmas.String(),
}, nil
}
type conn struct{ conn *sqlite3.Conn }
type conn struct {
conn *sqlite3.Conn
pragmas string
txBegin string
txReadOnly bool
}
var (
// Ensure these interfaces are implemented:
_ driver.Validator = conn{}
_ driver.ExecerContext = conn{}
// _ driver.ConnBeginTx = conn{}
// _ driver.SessionResetter = conn{}
_ driver.Validator = conn{}
_ driver.SessionResetter = conn{}
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
)
func (c conn) Close() error {
@@ -50,12 +83,40 @@ func (c conn) Close() error {
func (c conn) IsValid() bool {
// Pool only normal locking mode connections.
mode, _ := pragma(c.conn, "locking_mode")
return mode == "normal"
stmt, _, err := c.conn.Prepare(`PRAGMA locking_mode`)
if err != nil {
return false
}
defer stmt.Close()
return stmt.Step() && stmt.ColumnText(0) == "normal"
}
func (c conn) ResetSession(ctx context.Context) error {
return c.conn.Exec(c.pragmas)
}
func (c conn) Begin() (driver.Tx, error) {
err := c.conn.Exec(`BEGIN`)
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
txBegin := c.txBegin
if opts.ReadOnly {
txBegin = `
BEGIN deferred;
PRAGMA query_only=on;
`
}
c.txReadOnly = opts.ReadOnly
err := c.conn.Exec(txBegin)
if err != nil {
return nil, err
}
@@ -63,6 +124,9 @@ func (c conn) Begin() (driver.Tx, error) {
}
func (c conn) Commit() error {
if c.txReadOnly {
return c.Rollback()
}
err := c.conn.Exec(`COMMIT`)
if err != nil {
c.Rollback()
@@ -81,12 +145,14 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
}
if tail != "" {
// Check if the tail contains any SQL.
s, _, err := c.conn.Prepare(tail)
st, _, err := c.conn.Prepare(tail)
if err != nil {
s.Close()
return nil, err
}
if s != nil {
if st != nil {
s.Close()
st.Close()
return nil, tailErr
}
}
@@ -113,18 +179,6 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
}, nil
}
func pragma(c *sqlite3.Conn, pragma string) (string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + pragma)
if err != nil {
return "", err
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnText(0), nil
}
return "", stmt.Err()
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
@@ -132,8 +186,9 @@ type stmt struct {
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.NamedValueChecker = stmt{}
)
func (s stmt) Close() error {
@@ -141,7 +196,13 @@ func (s stmt) Close() error {
}
func (s stmt) NumInput() int {
return s.stmt.BindCount()
n := s.stmt.BindCount()
for i := 1; i <= n; i++ {
if s.stmt.BindName(i) != "" {
return -1
}
}
return n
}
// Deprecated: use ExecContext instead.
@@ -155,6 +216,8 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
// Use QueryContext to setup bindings.
// No need to close rows: that simply resets the statement, exec does the same.
_, err := s.QueryContext(ctx, args)
if err != nil {
return nil, err
@@ -194,6 +257,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
case int:
err = s.stmt.BindInt(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
case float64:
@@ -202,6 +267,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
err = s.stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
case nil:
@@ -218,6 +285,16 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
return rows{ctx, s.stmt, s.conn}, nil
}
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
return nil
default:
return driver.ErrSkip
}
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {

348
driver/driver_test.go Normal file
View File

@@ -0,0 +1,348 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
"bytes"
"context"
"database/sql"
"errors"
"math"
"path/filepath"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func Test_Open_dir(t *testing.T) {
db, err := sql.Open("sqlite3", ".")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
}
}
func Test_Open_pragma(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var timeout int
err = db.QueryRow(`PRAGMA busy_timeout`).Scan(&timeout)
if err != nil {
t.Fatal(err)
}
if timeout != 1000 {
t.Errorf("got %v, want 1000", timeout)
}
}
func Test_Open_pragma_invalid(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` {
t.Error("got message: ", got)
}
}
func Test_Open_txLock(t *testing.T) {
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "test.db")+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx1, err := db.Begin()
if err != nil {
t.Fatal(err)
}
_, err = db.Begin()
if err == nil {
t.Error("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked` {
t.Error("got message: ", got)
}
err = tx1.Commit()
if err != nil {
t.Fatal(err)
}
}
func Test_Open_txLock_invalid(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
if err == nil {
t.Fatal("want error")
}
if got := err.Error(); got != `sqlite3: invalid _txlock: xclusive` {
t.Error("got message: ", got)
}
}
func Test_BeginTx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(isolationErr) {
t.Error("want isolationErr")
}
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
_, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err == nil {
t.Error("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.READONLY {
t.Errorf("got %d, want sqlite3.READONLY", rc)
}
if got := err.Error(); got != `sqlite3: attempt to write a readonly database` {
t.Error("got message: ", got)
}
err = tx2.Commit()
if err != nil {
t.Fatal(err)
}
err = tx1.Commit()
if err != nil {
t.Fatal(err)
}
}
func Test_Prepare(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
var serr *sqlite3.Error
_, err = db.Prepare(`SELECT`)
if err == nil {
t.Error("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, err = db.Prepare(`SELECT 1; SELECT`)
if err == nil {
t.Error("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
if err.Error() != string(tailErr) {
t.Error("want tailErr")
}
}
func Test_QueryRow_named(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
date := time.Now()
row := stmt.QueryRow(true, sql.Named("AAA", math.Pi), nil /*3*/, nil /*4*/, date /*5*/)
var first bool
var fifth time.Time
var colon, at, dollar float32
err = row.Scan(&first, &fifth, &colon, &at, &dollar)
if err != nil {
t.Fatal(err)
}
if first != true {
t.Errorf("want true, got %v", first)
}
if colon != math.Pi {
t.Errorf("want π, got %v", colon)
}
if at != math.Pi {
t.Errorf("want π, got %v", at)
}
if dollar != math.Pi {
t.Errorf("want π, got %v", dollar)
}
if !fifth.Equal(date) {
t.Errorf("want %v, got %v", date, fifth)
}
}
func Test_QueryRow_blob_null(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
rows, err := db.Query(`
SELECT NULL UNION ALL
SELECT x'cafe' UNION ALL
SELECT x'babe' UNION ALL
SELECT NULL
`)
if err != nil {
t.Fatal(err)
}
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
for i := 0; rows.Next(); i++ {
var buf []byte
err = rows.Scan(&buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, want[i]) {
t.Errorf("got %q, want %q", buf, want[i])
}
}
}
func Test_ZeroBlob(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
_, err = conn.ExecContext(ctx, `INSERT INTO test(col) VALUES(?)`, sqlite3.ZeroBlob(4))
if err != nil {
t.Fatal(err)
}
var got []byte
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}

View File

@@ -5,6 +5,7 @@ type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
isolationErr = errorString("sqlite3: unsupported isolation level")
)

View File

@@ -8,9 +8,21 @@ import (
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database.sql] will recover the same string.
// TODO: optimize and fuzz test.
// but if a string is needed, [database/sql] will recover the same string.
func maybeDate(text string) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return text
}
if len(text) > len(time.RFC3339Nano) {
return text
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return text
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
return date

46
driver/time_test.go Normal file
View File

@@ -0,0 +1,46 @@
package driver
import (
"testing"
"time"
)
func Fuzz_maybeDate(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
f.Add(time.RFC3339)
f.Add(time.RFC3339Nano)
f.Add(time.Layout)
f.Add(time.DateTime)
f.Add(time.DateOnly)
f.Add(time.TimeOnly)
f.Add("2006-01-02T15:04:05Z")
f.Add("2006-01-02T15:04:05.000Z")
f.Add("2006-01-02T15:04:05.9999999999Z")
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeDate(str)
switch v := value.(type) {
case time.Time:
// Make sure times round-trip to the same string:
// https://pkg.go.dev/database/sql#Rows.Scan
if v.Format(time.RFC3339Nano) != str {
t.Fatalf("did not round-trip: %q", str)
}
case string:
if v != str {
t.Fatalf("did not round-trip: %q", str)
}
date, err := time.Parse(time.RFC3339Nano, str)
if err == nil && date.Format(time.RFC3339Nano) == str {
t.Fatalf("would round-trip: %q", str)
}
default:
t.Fatalf("invalid type %T: %q", v, str)
}
})
}

View File

@@ -50,6 +50,11 @@ func (e *Error) Error() string {
return b.String()
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
}
// SQL returns the SQL starting at the token that triggered a syntax error.
func (e *Error) SQL() string {
return e.sql
@@ -60,12 +65,14 @@ type errorString string
func (e errorString) Error() string { return string(e) }
const (
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
timeErr = errorString("sqlite3: invalid time value")
)
func assertErr() errorString {

2
go.mod
View File

@@ -4,7 +4,7 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-pre.8
github.com/tetratelabs/wazero v1.0.0-pre.9
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
)

4
go.sum
View File

@@ -1,7 +1,7 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-pre.8 h1:Ir82PWj79WCppH+9ny73eGY2qv+oCnE3VwMY92cBSyI=
github.com/tetratelabs/wazero v1.0.0-pre.8/go.mod h1:u8wrFmpdrykiFK0DFPiFm5a4+0RzsdmXYVtijBKqUVo=
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=

161
mock_test.go Normal file
View File

@@ -0,0 +1,161 @@
package sqlite3
import (
"context"
"encoding/binary"
"math"
"github.com/tetratelabs/wazero/api"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func newMemory(size uint32) memory {
mem := make(mockMemory, size)
return memory{mockModule{&mem}}
}
type mockModule struct {
memory api.Memory
}
func (m mockModule) Memory() api.Memory { return m.memory }
func (m mockModule) String() string { return "mockModule" }
func (m mockModule) Name() string { return "mockModule" }
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
func (m mockModule) Close(context.Context) error { return nil }
type mockMemory []byte
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
if offset >= m.Size() {
return 0, false
}
return m[offset], true
}
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
if !m.hasSize(offset, 2) {
return 0, false
}
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
}
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
if !m.hasSize(offset, 4) {
return 0, false
}
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
}
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
v, ok := m.ReadUint32Le(offset)
if !ok {
return 0, false
}
return math.Float32frombits(v), true
}
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
if !m.hasSize(offset, 8) {
return 0, false
}
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
}
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
v, ok := m.ReadUint64Le(offset)
if !ok {
return 0, false
}
return math.Float64frombits(v), true
}
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
if !m.hasSize(offset, byteCount) {
return nil, false
}
return m[offset : offset+byteCount : offset+byteCount], true
}
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
if offset >= m.Size() {
return false
}
m[offset] = v
return true
}
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
if !m.hasSize(offset, 2) {
return false
}
binary.LittleEndian.PutUint16(m[offset:], v)
return true
}
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
if !m.hasSize(offset, 4) {
return false
}
binary.LittleEndian.PutUint32(m[offset:], v)
return true
}
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
return m.WriteUint32Le(offset, math.Float32bits(v))
}
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
if !m.hasSize(offset, 8) {
return false
}
binary.LittleEndian.PutUint64(m[offset:], v)
return true
}
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
return m.WriteUint64Le(offset, math.Float64bits(v))
}
func (m mockMemory) Write(offset uint32, val []byte) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
return true
}
func (m mockMemory) WriteString(offset uint32, val string) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
return true
}
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
prev := (len(*m) + 65535) / 65536
*m = append(*m, make([]byte, 65536*delta)...)
return uint32(prev), true
}
func (m mockMemory) PageSize() (result uint32) {
return uint32(len(m) / 65536)
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
}

61
stmt.go
View File

@@ -2,6 +2,7 @@ package sqlite3
import (
"math"
"time"
)
// Stmt is a prepared statement object.
@@ -15,6 +16,8 @@ type Stmt struct {
// Close destroys the prepared statement object.
//
// It is safe to close a nil, zero or closed prepared statement.
//
// https://www.sqlite.org/c3ref/finalize.html
func (s *Stmt) Close() error {
if s == nil || s.handle == 0 {
@@ -219,6 +222,19 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
return s.c.error(r[0])
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r, err := s.c.api.bindZeroBlob.Call(s.c.ctx,
uint64(s.handle), uint64(param), uint64(n))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
// BindNull binds a NULL to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
@@ -232,6 +248,24 @@ func (s *Stmt) BindNull(param int) error {
return s.c.error(r[0])
}
// BindTime binds a [time.Time] to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
case int64:
s.BindInt64(param, v)
case float64:
s.BindFloat(param, v)
default:
panic(assertErr())
}
return nil
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -257,7 +291,7 @@ func (s *Stmt) ColumnName(col int) string {
ptr := uint32(r[0])
if ptr == 0 {
return ""
panic(oomErr)
}
return s.c.mem.readString(ptr, _MAX_STRING)
}
@@ -323,6 +357,31 @@ func (s *Stmt) ColumnFloat(col int) float64 {
return math.Float64frombits(r[0])
}
// ColumnTime returns the value of the result column as a [time.Time].
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
var v any
switch s.ColumnType(col) {
case INTEGER:
v = s.ColumnInt64(col)
case FLOAT:
v = s.ColumnFloat(col)
case TEXT, BLOB:
v = s.ColumnText(col)
case NULL:
return time.Time{}
default:
panic(assertErr())
}
t, err := format.Decode(v)
if err != nil {
s.err = err
}
return t
}
// ColumnText returns the value of the result column as a string.
// The leftmost column of the result set has the index 0.
//

180
tests/bradfitz/sql_test.go Normal file
View File

@@ -0,0 +1,180 @@
package bradfitz
// Adapted from: https://github.com/bradfitz/go-sql-test
import (
"database/sql"
"fmt"
"math/rand"
"path/filepath"
"sync"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
type Tester interface {
RunTest(*testing.T, func(params))
}
var (
sqlite Tester = sqliteDB{}
)
const TablePrefix = "gosqltest_"
type sqliteDB struct{}
type params struct {
dbType Tester
*testing.T
*sql.DB
}
func (t params) mustExec(sql string, args ...interface{}) sql.Result {
res, err := t.DB.Exec(sql, args...)
if err != nil {
t.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db"))
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}
fn(params{sqlite, t, db})
if err := db.Close(); err != nil {
t.Fatalf("foo.db close fail: %v", err)
}
}
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
func testBlobs(t params) {
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar blob)")
t.mustExec("insert into "+TablePrefix+"foo (id, bar) values(?,?)", 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) }
func testManyQueryRow(t params) {
if testing.Short() {
t.Skip("skipping in short mode")
}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
func TestTxQuery_SQLite(t *testing.T) { sqlite.RunTest(t, testTxQuery) }
func testTxQuery(t params) {
tx, err := t.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = t.DB.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
if err != nil {
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
}
_, err = tx.Exec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query("select name from "+TablePrefix+"foo where id = ?", 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) }
func testPreparedStmt(t params) {
if testing.Short() {
t.Skip("skipping in short mode")
}
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := t.Prepare("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/ncruces/go-sqlite3"
)
func TestCompile_empty(t *testing.T) {
func TestCompile_missing(t *testing.T) {
sqlite3.Path = "sqlite3.wasm"
_, err := sqlite3.Open(":memory:")
if err == nil {

View File

@@ -0,0 +1,14 @@
package compile
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestCompile_nil(t *testing.T) {
_, err := sqlite3.Open(":memory:")
if err == nil {
t.Error("want error")
}
}

229
tests/conn_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tests
import (
"context"
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
}
}
func TestConn_Close(t *testing.T) {
var conn *sqlite3.Conn
conn.Close()
}
func TestConn_Close_BUSY(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
db.SetInterrupt(nil)
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
cancel()
db.SetInterrupt(ctx.Done())
var serr *sqlite3.Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
db.SetInterrupt(nil)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if !strings.Contains(tail, "-- HERE") {
t.Errorf("got %q", tail)
}
}
func TestConn_Prepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *sqlite3.Error
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
}
}

View File

@@ -17,6 +17,8 @@ func TestDB_file(t *testing.T) {
}
func testDB(t *testing.T, name string) {
t.Parallel()
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)
@@ -32,6 +34,10 @@ func testDB(t *testing.T, name string) {
if err != nil {
t.Fatal(err)
}
changes := db.Changes()
if changes != 3 {
t.Errorf("got %d want 3", changes)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
@@ -43,18 +49,22 @@ func testDB(t *testing.T, name string) {
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; stmt.Step(); row++ {
if ids[row] != stmt.ColumnInt(0) {
t.Errorf("got %d, want %d", stmt.ColumnInt(0), ids[row])
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if names[row] != stmt.ColumnText(1) {
t.Errorf("got %q, want %q", stmt.ColumnText(1), names[row])
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
if row != 3 {
t.Errorf("got %d rows, want %d", row, len(ids))
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()

View File

@@ -1,26 +0,0 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDir(t *testing.T) {
_, err := sqlite3.Open(".")
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
}
}

103
tests/driver_test.go Normal file
View File

@@ -0,0 +1,103 @@
package tests
import (
"context"
"database/sql"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDriver(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
res, err := conn.ExecContext(ctx,
`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
changes, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if changes != 3 {
t.Errorf("got %d want 3", changes)
}
stmt, err := conn.PrepareContext(context.Background(),
`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
t.Fatal(err)
}
defer rows.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; rows.Next(); row++ {
var id int
var name string
err := rows.Scan(&id, &name)
if err != nil {
t.Fatal(err)
}
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
err = rows.Close()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -1,6 +1,7 @@
package tests
import (
"errors"
"io"
"os"
"os/exec"
@@ -21,7 +22,7 @@ func TestParallel(t *testing.T) {
func TestMultiProcess(t *testing.T) {
if testing.Short() {
return
t.Skip()
}
name := filepath.Join(t.TempDir(), "test.db")
@@ -44,7 +45,11 @@ func TestMultiProcess(t *testing.T) {
testParallel(t, name, 1000)
if err := cmd.Wait(); err != nil {
t.Fatal(err)
t.Error(err)
var eerr *exec.ExitError
if errors.As(err, &eerr) {
t.Error(eerr.Stderr)
}
}
testIntegrity(t, name)
}
@@ -52,7 +57,7 @@ func TestMultiProcess(t *testing.T) {
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
return
t.SkipNow()
}
testParallel(t, name, 1000)

View File

@@ -1,14 +1,17 @@
package sqlite3
package tests
import (
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestStmt(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
@@ -29,103 +32,80 @@ func TestStmt(t *testing.T) {
t.Errorf("got %d, want 1", got)
}
err = stmt.BindBool(1, false)
if err != nil {
if err := stmt.BindBool(1, false); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindBool(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.ClearBindings()
if err != nil {
if err := stmt.BindInt(1, 2); err != nil {
t.Fatal(err)
}
if err = stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindFloat(1, math.Pi); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindBool(1, true)
if err != nil {
if err := stmt.BindNull(1); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindText(1, ""); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindInt(1, 2)
if err != nil {
if err := stmt.BindText(1, "text"); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindFloat(1, math.Pi)
if err != nil {
if err := stmt.BindBlob(1, nil); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.BindZeroBlob(1, 4); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.BindNull(1)
if err != nil {
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "text")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, []byte("blob"))
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, nil)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
@@ -134,7 +114,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
// The table should have: 0, NULL, 1, 2, π, NULL, "", "text", `blob`, NULL
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
@@ -142,7 +122,7 @@ func TestStmt(t *testing.T) {
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -163,28 +143,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
@@ -205,7 +164,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
@@ -226,7 +185,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != FLOAT {
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnBool(0); got != true {
@@ -247,7 +206,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -268,7 +227,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -289,7 +248,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -310,7 +269,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != BLOB {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -331,7 +290,7 @@ func TestStmt(t *testing.T) {
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
@@ -351,24 +310,66 @@ func TestStmt(t *testing.T) {
}
}
err = stmt.Close()
if err != nil {
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if err := stmt.Close(); err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestStmt_Close(t *testing.T) {
var stmt *Stmt
var stmt *sqlite3.Stmt
stmt.Close()
}
func TestStmt_BindName(t *testing.T) {
db, err := Open(":memory:")
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
@@ -398,3 +399,65 @@ func TestStmt_BindName(t *testing.T) {
}
}
}
func TestStmt_Time(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT ?, ?, ?, datetime(), unixepoch(), julianday(), NULL, 'abc'`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
err = stmt.BindTime(1, reference, sqlite3.TimeFormat4)
if err != nil {
t.Fatal(err)
}
err = stmt.BindTime(2, reference, sqlite3.TimeFormatUnixMilli)
if err != nil {
t.Fatal(err)
}
err = stmt.BindTime(3, reference, sqlite3.TimeFormatJulianDay)
if err != nil {
t.Fatal(err)
}
if now := time.Now(); stmt.Step() {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !reference.Equal(got) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !reference.Equal(got) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); reference.Sub(got) > time.Millisecond {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); now.Sub(got) > time.Millisecond {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(6, sqlite3.TimeFormatAuto); got != (time.Time{}) {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnTime(7, sqlite3.TimeFormatAuto); got != (time.Time{}) {
t.Errorf("got %v, want zero", got)
}
if stmt.Err() == nil {
t.Errorf("want error")
}
}
}

317
time.go Normal file
View File

@@ -0,0 +1,317 @@
package sqlite3
import (
"math"
"strconv"
"strings"
"time"
"github.com/ncruces/julianday"
)
// TimeFormat specifies how to encode/decode time values.
//
// https://www.sqlite.org/lang_datefunc.html
type TimeFormat string
// TimeFormats recognized by SQLite to encode/decode time values.
//
// https://www.sqlite.org/lang_datefunc.html
const (
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano
// Text formats
TimeFormat1 TimeFormat = "2006-01-02"
TimeFormat2 TimeFormat = "2006-01-02 15:04"
TimeFormat3 TimeFormat = "2006-01-02 15:04:05"
TimeFormat4 TimeFormat = "2006-01-02 15:04:05.000"
TimeFormat5 TimeFormat = "2006-01-02T15:04"
TimeFormat6 TimeFormat = "2006-01-02T15:04:05"
TimeFormat7 TimeFormat = "2006-01-02T15:04:05.000"
TimeFormat8 TimeFormat = "15:04"
TimeFormat9 TimeFormat = "15:04:05"
TimeFormat10 TimeFormat = "15:04:05.000"
TimeFormat2TZ = TimeFormat2 + "Z07:00"
TimeFormat3TZ = TimeFormat3 + "Z07:00"
TimeFormat4TZ = TimeFormat4 + "Z07:00"
TimeFormat5TZ = TimeFormat5 + "Z07:00"
TimeFormat6TZ = TimeFormat6 + "Z07:00"
TimeFormat7TZ = TimeFormat7 + "Z07:00"
TimeFormat8TZ = TimeFormat8 + "Z07:00"
TimeFormat9TZ = TimeFormat9 + "Z07:00"
TimeFormat10TZ = TimeFormat10 + "Z07:00"
// Numeric formats
TimeFormatJulianDay TimeFormat = "julianday"
TimeFormatUnix TimeFormat = "unixepoch"
TimeFormatUnixFrac TimeFormat = "unixepoch_frac"
TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format
TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format
TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format
// Auto
TimeFormatAuto TimeFormat = "auto"
)
// Encode encodes a time value using this format.
//
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving timezone.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
//
// Returns a string for the text formats,
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
case TimeFormatJulianDay:
return julianday.Float(t)
case TimeFormatUnix:
return t.Unix()
case TimeFormatUnixFrac:
return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
case TimeFormatUnixMilli:
return t.UnixMilli()
case TimeFormatUnixMicro:
return t.UnixMicro()
case TimeFormatUnixNano:
return t.UnixNano()
// Special formats
case TimeFormatDefault, TimeFormatAuto:
f = time.RFC3339Nano
// SQLite assumes UTC if unspecified.
case
TimeFormat1, TimeFormat2,
TimeFormat3, TimeFormat4,
TimeFormat5, TimeFormat6,
TimeFormat7, TimeFormat8,
TimeFormat9, TimeFormat10:
t = t.UTC()
}
return t.Format(string(f))
}
// Decode decodes a time value using this format.
//
// The time value can be a string, an int64, or a float64.
//
// Formats [TimeFormat8] through [TimeFormat10]
// assume a date of 2000-01-01.
//
// The timezone indicator and fractional seconds are always optional
// for formats [TimeFormat2] through [TimeFormat10].
//
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// The julian day number is safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
// are safe to use for current events, from 1980 through at least 2260.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) {
switch f {
// Numeric formats
case TimeFormatJulianDay:
switch v := v.(type) {
case string:
return julianday.Parse(v)
case float64:
return julianday.FloatTime(v), nil
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
if s, ok := v.(string); ok {
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return time.Time{}, err
}
v = f
}
switch v := v.(type) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)), nil
case int64:
return time.Unix(v, 0), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnixMilli:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(math.Floor(v))), nil
case int64:
return time.UnixMilli(int64(v)), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnixMicro:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(math.Floor(v))), nil
case int64:
return time.UnixMicro(int64(v)), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, timeErr
}
v = i
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(math.Floor(v))), nil
case int64:
return time.Unix(0, int64(v)), nil
default:
return time.Time{}, timeErr
}
// Special formats
case TimeFormatAuto:
switch s := v.(type) {
case string:
i, err := strconv.ParseInt(s, 10, 64)
if err == nil {
v = i
break
}
f, err := strconv.ParseFloat(s, 64)
if err == nil {
v = f
break
}
dates := []TimeFormat{
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
TimeFormat1,
}
for _, f := range dates {
t, err := time.Parse(string(f), s)
if err == nil {
return t, nil
}
}
times := []TimeFormat{
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
}
for _, f := range times {
t, err := time.Parse(string(f), s)
if err == nil {
return t.AddDate(2000, 0, 0), nil
}
}
}
switch v := v.(type) {
case float64:
if 0 <= v && v < 5373484.5 {
return TimeFormatJulianDay.Decode(v)
}
if v < 253402300800 {
return TimeFormatUnixFrac.Decode(v)
}
if v < 253402300800_000 {
return TimeFormatUnixMilli.Decode(v)
}
if v < 253402300800_000000 {
return TimeFormatUnixMicro.Decode(v)
}
return TimeFormatUnixNano.Decode(v)
case int64:
if 0 <= v && v < 5373485 {
return TimeFormatJulianDay.Decode(v)
}
if v < 253402300800 {
return TimeFormatUnixFrac.Decode(v)
}
if v < 253402300800_000 {
return TimeFormatUnixMilli.Decode(v)
}
if v < 253402300800_000000 {
return TimeFormatUnixMicro.Decode(v)
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, timeErr
}
case
TimeFormat2, TimeFormat2TZ,
TimeFormat3, TimeFormat3TZ,
TimeFormat4, TimeFormat4TZ,
TimeFormat5, TimeFormat5TZ,
TimeFormat6, TimeFormat6TZ,
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
return f.parseRelaxed(s)
case
TimeFormat8, TimeFormat8TZ,
TimeFormat9, TimeFormat9TZ,
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
default:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
if f == "" {
f = time.RFC3339Nano
}
return time.Parse(string(f), s)
}
}
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
fs := string(f)
fs = strings.TrimSuffix(fs, "Z07:00")
fs = strings.TrimSuffix(fs, ".000")
t, err := time.Parse(fs+"Z07:00", s)
if err != nil {
return time.Parse(fs, s)
}
return t, nil
}

118
time_test.go Normal file
View File

@@ -0,0 +1,118 @@
package sqlite3
import (
"reflect"
"testing"
"time"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
time time.Time
want any
}{
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{TimeFormatJulianDay, reference, 2456572.849526851851852},
{TimeFormatUnix, reference, int64(1381134199)},
{TimeFormatUnixFrac, reference, 1381134199.120},
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{TimeFormatJulianDay, false, time.Time{}, 0, true},
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{TimeFormatUnix, "abc", time.Time{}, 0, true},
{TimeFormatUnix, false, time.Time{}, 0, true},
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{TimeFormatUnixNano, false, time.Time{}, 0, true},
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199", reference, time.Second, false},
{TimeFormatAuto, "1381134199120", reference, 0, false},
{TimeFormatAuto, "1381134199120000", reference, 0, false},
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormatAuto, "abc", time.Time{}, 0, true},
{TimeFormatAuto, false, time.Time{}, 0, true},
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormat9, "08:23:19.12", reftime, 0, false},
{TimeFormat3, false, time.Time{}, 0, true},
{TimeFormat9, false, time.Time{}, 0, true},
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}

16
util.go Normal file
View File

@@ -0,0 +1,16 @@
package sqlite3
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

View File

@@ -1,161 +1,60 @@
package sqlite3
import (
"context"
"encoding/binary"
"math"
"github.com/tetratelabs/wazero/api"
"testing"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func Test_emptyStatement(t *testing.T) {
t.Parallel()
func newMemory(size uint32) memory {
mem := make(mockMemory, size)
return memory{mockModule{&mem}}
}
type mockModule struct {
memory api.Memory
}
func (m mockModule) Memory() api.Memory { return m.memory }
func (m mockModule) String() string { return "mockModule" }
func (m mockModule) Name() string { return "mockModule" }
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
func (m mockModule) Close(context.Context) error { return nil }
type mockMemory []byte
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
if offset >= m.Size() {
return 0, false
tests := []struct {
name string
stmt string
want bool
}{
{"empty", "", true},
{"space", " ", true},
{"separator", ";\n ", true},
{"begin", "BEGIN", false},
{"select", "SELECT 1;", false},
}
return m[offset], true
}
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
if !m.hasSize(offset, 2) {
return 0, false
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := emptyStatement(tt.stmt); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
}
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
if !m.hasSize(offset, 4) {
return 0, false
func Fuzz_emptyStatement(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add(";\n ")
f.Add("; ;\v")
f.Add("BEGIN")
f.Add("SELECT 1;")
db, err := Open(":memory:")
if err != nil {
f.Fatal(err)
}
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
}
defer db.Close()
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
v, ok := m.ReadUint32Le(offset)
if !ok {
return 0, false
}
return math.Float32frombits(v), true
}
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
if !m.hasSize(offset, 8) {
return 0, false
}
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
}
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
v, ok := m.ReadUint64Le(offset)
if !ok {
return 0, false
}
return math.Float64frombits(v), true
}
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
if !m.hasSize(offset, byteCount) {
return nil, false
}
return m[offset : offset+byteCount : offset+byteCount], true
}
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
if offset >= m.Size() {
return false
}
m[offset] = v
return true
}
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
if !m.hasSize(offset, 2) {
return false
}
binary.LittleEndian.PutUint16(m[offset:], v)
return true
}
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
if !m.hasSize(offset, 4) {
return false
}
binary.LittleEndian.PutUint32(m[offset:], v)
return true
}
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
return m.WriteUint32Le(offset, math.Float32bits(v))
}
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
if !m.hasSize(offset, 8) {
return false
}
binary.LittleEndian.PutUint64(m[offset:], v)
return true
}
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
return m.WriteUint64Le(offset, math.Float64bits(v))
}
func (m mockMemory) Write(offset uint32, val []byte) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
return true
}
func (m mockMemory) WriteString(offset uint32, val string) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
return true
}
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
prev := (len(*m) + 65535) / 65536
*m = append(*m, make([]byte, 65536*delta)...)
return uint32(prev), true
}
func (m mockMemory) PageSize() (result uint32) {
return uint32(len(m) / 65536)
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
f.Fuzz(func(t *testing.T, sql string) {
// If empty, SQLite parses it as empty.
if emptyStatement(sql) {
stmt, tail, err := db.Prepare(sql)
if err != nil {
t.Errorf("%q, %v", sql, err)
}
if stmt != nil {
t.Errorf("%q, %v", sql, stmt)
}
if tail != "" {
t.Errorf("%q", sql)
}
stmt.Close()
}
})
}

View File

@@ -242,8 +242,9 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
// Release the file lock only when all connections have released the lock.
ptr.SetLock(_NO_LOCK)
if fLock.shared--; fLock.shared == 0 {
rc := fLock.Release()
fLock.state = _NO_LOCK
return uint32(fLock.Release())
return uint32(rc)
}
return _OK
}

View File

@@ -3,6 +3,7 @@ package sqlite3
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
)
@@ -16,16 +17,15 @@ func Test_vfsLock(t *testing.T) {
t.Skip()
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.CreateTemp("", "sqlite3-")
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
name := file1.Name()
defer os.RemoveAll(name)
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"io/fs"
"os"
"path/filepath"
"syscall"
"testing"
"time"
@@ -136,12 +137,12 @@ func Test_vfsFullPathname(t *testing.T) {
}
func Test_vfsDelete(t *testing.T) {
file, err := os.CreateTemp("", "sqlite3-")
name := filepath.Join(t.TempDir(), "test.db")
file, err := os.Create(name)
if err != nil {
t.Fatal(err)
}
name := file.Name()
defer os.RemoveAll(name)
file.Close()
mem := newMemory(128 + _MAX_PATHNAME)
@@ -163,8 +164,19 @@ func Test_vfsDelete(t *testing.T) {
}
func Test_vfsAccess(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
f.Close()
}
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
t.Fatal(err)
}
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, t.TempDir())
mem.writeString(8, dir)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
@@ -181,6 +193,15 @@ func Test_vfsAccess(t *testing.T) {
if got := mem.readUint32(4); got != 1 {
t.Error("can't access directory")
}
mem.writeString(8, file)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {

View File

@@ -33,16 +33,17 @@ func (l *vfsFileLocker) GetExclusive() xErrorCode {
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
// Downgrade to a SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
if l.state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return l.unlock(_PENDING_BYTE, 2)
}

View File

@@ -39,27 +39,39 @@ func (l *vfsFileLocker) GetExclusive() xErrorCode {
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
if l.state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
// Reacquire the SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
l.unlock(_RESERVED_BYTE, 1)
l.unlock(_PENDING_BYTE, 1)
if l.state >= _RESERVED_LOCK {
l.unlock(_RESERVED_BYTE, 1)
}
if l.state >= _PENDING_LOCK {
l.unlock(_PENDING_BYTE, 1)
}
return _OK
}
func (l *vfsFileLocker) Release() xErrorCode {
// Release all locks.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
l.unlock(_RESERVED_BYTE, 1)
l.unlock(_PENDING_BYTE, 1)
if l.state >= _RESERVED_LOCK {
l.unlock(_RESERVED_BYTE, 1)
}
if l.state >= _SHARED_LOCK {
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
}
if l.state >= _PENDING_LOCK {
l.unlock(_PENDING_BYTE, 1)
}
return _OK
}