mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-15 15:19:13 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71ae26e5c9 | ||
|
|
e91758c6a4 | ||
|
|
b749b32a62 | ||
|
|
3b4df71a94 | ||
|
|
df687a1c54 | ||
|
|
2f5b9837e1 | ||
|
|
c351400be7 | ||
|
|
231d3a0438 | ||
|
|
2f25e4eedb | ||
|
|
ad27d5d840 |
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6
|
||||
2
.github/workflows/go.yml
vendored
2
.github/workflows/go.yml
vendored
@@ -36,5 +36,5 @@ jobs:
|
||||
uses: ncruces/go-coverage-report@main
|
||||
if: |
|
||||
matrix.os == 'ubuntu-latest' &&
|
||||
github.event_name == 'push'
|
||||
github.event_name == 'push'
|
||||
continue-on-error: true
|
||||
|
||||
12
README.md
12
README.md
@@ -6,8 +6,7 @@
|
||||
|
||||
⚠️ CAUTION ⚠️
|
||||
|
||||
This is a WIP.\
|
||||
DO NOT USE with data you care about.
|
||||
This is a WIP.
|
||||
|
||||
Roadmap:
|
||||
- [x] build SQLite using `zig cc --target=wasm32-wasi`
|
||||
@@ -18,4 +17,11 @@ Roadmap:
|
||||
- [x] provide a simple `database/sql` driver
|
||||
- [x] file locking, compatible with SQLite on Windows/Unix
|
||||
- [ ] shared memory, compatible with SQLite on Windows/Unix
|
||||
- needed for improved WAL mode
|
||||
- needed for improved WAL mode
|
||||
- [ ] advanced SQLite features
|
||||
- [ ] nested transactions
|
||||
- [ ] incremental BLOB I/O
|
||||
- [ ] online backup
|
||||
- [ ] session extension
|
||||
- [ ] snapshots
|
||||
- [ ] SQL functions
|
||||
6
blob.go
Normal file
6
blob.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package sqlite3
|
||||
|
||||
// ZeroBlob represents a zero-filled, length n BLOB
|
||||
// that can be used as an argument to
|
||||
// [database/sql.DB.Exec] and similar methods.
|
||||
type ZeroBlob int64
|
||||
@@ -49,6 +49,10 @@ func (s *sqlite3Runtime) compileModule(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if bin == nil {
|
||||
s.err = binaryErr
|
||||
return
|
||||
}
|
||||
|
||||
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
6
conn.go
6
conn.go
@@ -68,6 +68,8 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
// open blob handles, and/or unfinished backup objects,
|
||||
// Close will leave the database connection open and return [BUSY].
|
||||
//
|
||||
// It is safe to close a nil, zero or closed connection.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/close.html
|
||||
func (c *Conn) Close() error {
|
||||
if c == nil || c.handle == 0 {
|
||||
@@ -179,6 +181,10 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/prepare.html
|
||||
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
|
||||
if emptyStatement(sql) {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
defer c.arena.reset()
|
||||
stmtPtr := c.arena.new(ptrlen)
|
||||
tailPtr := c.arena.new(ptrlen)
|
||||
|
||||
177
conn_test.go
177
conn_test.go
@@ -2,187 +2,10 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
var conn *Conn
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestConn_Close_BUSY(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`BEGIN`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != BUSY {
|
||||
t.Errorf("got %d, want sqlite3.BUSY", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_SetInterrupt(t *testing.T) {
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
// Interrupt doesn't interrupt this.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
WITH RECURSIVE
|
||||
fibonacci (curr, next)
|
||||
AS (
|
||||
SELECT 0, 1
|
||||
UNION ALL
|
||||
SELECT next, curr + next FROM fibonacci
|
||||
LIMIT 1e6
|
||||
)
|
||||
SELECT min(curr) FROM fibonacci
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
cancel()
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
var serr *Error
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Interrupting sticks.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
// Interrupting can be cleared.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(``)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt != nil {
|
||||
t.Error("want nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var serr *Error
|
||||
|
||||
_, _, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.ERROR", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := serr.SQL(); got != `FRM sqlite_schema` {
|
||||
t.Error("got SQL: ", got)
|
||||
}
|
||||
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
1
const.go
1
const.go
@@ -197,6 +197,7 @@ const (
|
||||
NULL Datatype = 5
|
||||
)
|
||||
|
||||
// String implements the [fmt.Stringer] interface.
|
||||
func (t Datatype) String() string {
|
||||
const name = "INTEGERFLOATTEXTBLOBNULL"
|
||||
switch t {
|
||||
|
||||
@@ -3,6 +3,8 @@ package sqlite3
|
||||
import "testing"
|
||||
|
||||
func TestDatatype_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
data Datatype
|
||||
want string
|
||||
|
||||
143
driver/driver.go
143
driver/driver.go
@@ -5,7 +5,10 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -22,26 +25,56 @@ func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If the database is not in WAL mode,
|
||||
// use normal locking mode.
|
||||
journal, err := pragma(c, "journal_mode")
|
||||
|
||||
var txBegin string
|
||||
var pragmas strings.Builder
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, _ := url.ParseQuery(after)
|
||||
|
||||
switch s := query.Get("_txlock"); s {
|
||||
case "":
|
||||
txBegin = "BEGIN"
|
||||
case "deferred", "immediate", "exclusive":
|
||||
txBegin = "BEGIN " + s
|
||||
default:
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
|
||||
}
|
||||
|
||||
for _, p := range query["_pragma"] {
|
||||
pragmas.WriteString(`PRAGMA `)
|
||||
pragmas.WriteString(p)
|
||||
pragmas.WriteByte(';')
|
||||
}
|
||||
}
|
||||
if pragmas.Len() == 0 {
|
||||
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
|
||||
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
|
||||
}
|
||||
|
||||
err = c.Exec(pragmas.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
|
||||
}
|
||||
if journal != "wal" {
|
||||
pragma(c, "locking_mode=normal")
|
||||
}
|
||||
return conn{c}, nil
|
||||
return conn{
|
||||
conn: c,
|
||||
txBegin: txBegin,
|
||||
pragmas: pragmas.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type conn struct{ conn *sqlite3.Conn }
|
||||
type conn struct {
|
||||
conn *sqlite3.Conn
|
||||
pragmas string
|
||||
txBegin string
|
||||
txReadOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.Validator = conn{}
|
||||
_ driver.ExecerContext = conn{}
|
||||
// _ driver.ConnBeginTx = conn{}
|
||||
// _ driver.SessionResetter = conn{}
|
||||
_ driver.Validator = conn{}
|
||||
_ driver.SessionResetter = conn{}
|
||||
_ driver.ExecerContext = conn{}
|
||||
_ driver.ConnBeginTx = conn{}
|
||||
)
|
||||
|
||||
func (c conn) Close() error {
|
||||
@@ -50,12 +83,40 @@ func (c conn) Close() error {
|
||||
|
||||
func (c conn) IsValid() bool {
|
||||
// Pool only normal locking mode connections.
|
||||
mode, _ := pragma(c.conn, "locking_mode")
|
||||
return mode == "normal"
|
||||
stmt, _, err := c.conn.Prepare(`PRAGMA locking_mode`)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer stmt.Close()
|
||||
return stmt.Step() && stmt.ColumnText(0) == "normal"
|
||||
}
|
||||
|
||||
func (c conn) ResetSession(ctx context.Context) error {
|
||||
return c.conn.Exec(c.pragmas)
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
err := c.conn.Exec(`BEGIN`)
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
switch opts.Isolation {
|
||||
default:
|
||||
return nil, isolationErr
|
||||
case driver.IsolationLevel(sql.LevelDefault):
|
||||
case driver.IsolationLevel(sql.LevelSerializable):
|
||||
}
|
||||
|
||||
txBegin := c.txBegin
|
||||
if opts.ReadOnly {
|
||||
txBegin = `
|
||||
BEGIN deferred;
|
||||
PRAGMA query_only=on;
|
||||
`
|
||||
}
|
||||
c.txReadOnly = opts.ReadOnly
|
||||
|
||||
err := c.conn.Exec(txBegin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -63,6 +124,9 @@ func (c conn) Begin() (driver.Tx, error) {
|
||||
}
|
||||
|
||||
func (c conn) Commit() error {
|
||||
if c.txReadOnly {
|
||||
return c.Rollback()
|
||||
}
|
||||
err := c.conn.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
c.Rollback()
|
||||
@@ -81,12 +145,14 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
}
|
||||
if tail != "" {
|
||||
// Check if the tail contains any SQL.
|
||||
s, _, err := c.conn.Prepare(tail)
|
||||
st, _, err := c.conn.Prepare(tail)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
if s != nil {
|
||||
if st != nil {
|
||||
s.Close()
|
||||
st.Close()
|
||||
return nil, tailErr
|
||||
}
|
||||
}
|
||||
@@ -113,18 +179,6 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pragma(c *sqlite3.Conn, pragma string) (string, error) {
|
||||
stmt, _, err := c.Prepare(`PRAGMA ` + pragma)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0), nil
|
||||
}
|
||||
return "", stmt.Err()
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
@@ -132,8 +186,9 @@ type stmt struct {
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
_ driver.NamedValueChecker = stmt{}
|
||||
)
|
||||
|
||||
func (s stmt) Close() error {
|
||||
@@ -141,7 +196,13 @@ func (s stmt) Close() error {
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
return s.stmt.BindCount()
|
||||
n := s.stmt.BindCount()
|
||||
for i := 1; i <= n; i++ {
|
||||
if s.stmt.BindName(i) != "" {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Deprecated: use ExecContext instead.
|
||||
@@ -155,6 +216,8 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
// Use QueryContext to setup bindings.
|
||||
// No need to close rows: that simply resets the statement, exec does the same.
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -194,6 +257,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
switch a := arg.Value.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(id, a)
|
||||
case int:
|
||||
err = s.stmt.BindInt(id, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(id, a)
|
||||
case float64:
|
||||
@@ -202,6 +267,8 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
err = s.stmt.BindText(id, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(id, a)
|
||||
case sqlite3.ZeroBlob:
|
||||
err = s.stmt.BindZeroBlob(id, int64(a))
|
||||
case time.Time:
|
||||
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
|
||||
case nil:
|
||||
@@ -218,6 +285,16 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
return rows{ctx, s.stmt, s.conn}, nil
|
||||
}
|
||||
|
||||
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, nil:
|
||||
return nil
|
||||
default:
|
||||
return driver.ErrSkip
|
||||
}
|
||||
}
|
||||
|
||||
type result struct{ lastInsertId, rowsAffected int64 }
|
||||
|
||||
func (r result) LastInsertId() (int64, error) {
|
||||
|
||||
348
driver/driver_test.go
Normal file
348
driver/driver_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
// Package driver provides a database/sql driver for SQLite.
|
||||
package driver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func Test_Open_dir(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ".")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Conn(context.TODO())
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
|
||||
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: unable to open database file` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_pragma(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var timeout int
|
||||
err = db.QueryRow(`PRAGMA busy_timeout`).Scan(&timeout)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if timeout != 1000 {
|
||||
t.Errorf("got %v, want 1000", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_pragma_invalid(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Conn(context.TODO())
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_txLock(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.Join(t.TempDir(), "test.db")+
|
||||
"?_txlock=exclusive&_pragma=busy_timeout(0)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx1, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Begin()
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.BUSY {
|
||||
t.Errorf("got %d, want sqlite3.BUSY", rc)
|
||||
}
|
||||
var terr interface{ Temporary() bool }
|
||||
if !errors.As(err, &terr) || !terr.Temporary() {
|
||||
t.Error("not temporary", err)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: database is locked` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
err = tx1.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_txLock_invalid(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Conn(context.TODO())
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: invalid _txlock: xclusive` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BeginTx(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
if err.Error() != string(isolationErr) {
|
||||
t.Error("want isolationErr")
|
||||
}
|
||||
|
||||
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.READONLY {
|
||||
t.Errorf("got %d, want sqlite3.READONLY", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: attempt to write a readonly database` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
err = tx2.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = tx1.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prepare(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, err := db.Prepare(`SELECT 1; -- HERE`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
var serr *sqlite3.Error
|
||||
_, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
_, err = db.Prepare(`SELECT 1; SELECT`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
_, err = db.Prepare(`SELECT 1; SELECT 2`)
|
||||
if err.Error() != string(tailErr) {
|
||||
t.Error("want tailErr")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_QueryRow_named(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
date := time.Now()
|
||||
row := stmt.QueryRow(true, sql.Named("AAA", math.Pi), nil /*3*/, nil /*4*/, date /*5*/)
|
||||
|
||||
var first bool
|
||||
var fifth time.Time
|
||||
var colon, at, dollar float32
|
||||
err = row.Scan(&first, &fifth, &colon, &at, &dollar)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if first != true {
|
||||
t.Errorf("want true, got %v", first)
|
||||
}
|
||||
if colon != math.Pi {
|
||||
t.Errorf("want π, got %v", colon)
|
||||
}
|
||||
if at != math.Pi {
|
||||
t.Errorf("want π, got %v", at)
|
||||
}
|
||||
if dollar != math.Pi {
|
||||
t.Errorf("want π, got %v", dollar)
|
||||
}
|
||||
if !fifth.Equal(date) {
|
||||
t.Errorf("want %v, got %v", date, fifth)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_QueryRow_blob_null(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT NULL UNION ALL
|
||||
SELECT x'cafe' UNION ALL
|
||||
SELECT x'babe' UNION ALL
|
||||
SELECT NULL
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var buf []byte
|
||||
err = rows.Scan(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf, want[i]) {
|
||||
t.Errorf("got %q, want %q", buf, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ZeroBlob(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = conn.ExecContext(ctx, `INSERT INTO test(col) VALUES(?)`, sqlite3.ZeroBlob(4))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got []byte
|
||||
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ type errorString string
|
||||
func (e errorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
assertErr = errorString("sqlite3: assertion failed")
|
||||
tailErr = errorString("sqlite3: multiple statements")
|
||||
assertErr = errorString("sqlite3: assertion failed")
|
||||
tailErr = errorString("sqlite3: multiple statements")
|
||||
isolationErr = errorString("sqlite3: unsupported isolation level")
|
||||
)
|
||||
|
||||
@@ -8,9 +8,21 @@ import (
|
||||
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
|
||||
// if it roundtrips back to the same string.
|
||||
// This way times can be persisted to, and recovered from, the database,
|
||||
// but if a string is needed, [database.sql] will recover the same string.
|
||||
// TODO: optimize and fuzz test.
|
||||
// but if a string is needed, [database/sql] will recover the same string.
|
||||
func maybeDate(text string) driver.Value {
|
||||
// Weed out (some) values that can't possibly be
|
||||
// [time.RFC3339Nano] timestamps.
|
||||
if len(text) < len("2006-01-02T15:04:05Z") {
|
||||
return text
|
||||
}
|
||||
if len(text) > len(time.RFC3339Nano) {
|
||||
return text
|
||||
}
|
||||
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
|
||||
return text
|
||||
}
|
||||
|
||||
// Slow path.
|
||||
date, err := time.Parse(time.RFC3339Nano, text)
|
||||
if err == nil && date.Format(time.RFC3339Nano) == text {
|
||||
return date
|
||||
|
||||
46
driver/time_test.go
Normal file
46
driver/time_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Fuzz_maybeDate(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add("SQLite")
|
||||
f.Add(time.RFC3339)
|
||||
f.Add(time.RFC3339Nano)
|
||||
f.Add(time.Layout)
|
||||
f.Add(time.DateTime)
|
||||
f.Add(time.DateOnly)
|
||||
f.Add(time.TimeOnly)
|
||||
f.Add("2006-01-02T15:04:05Z")
|
||||
f.Add("2006-01-02T15:04:05.000Z")
|
||||
f.Add("2006-01-02T15:04:05.9999999999Z")
|
||||
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
|
||||
f.Fuzz(func(t *testing.T, str string) {
|
||||
value := maybeDate(str)
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
// Make sure times round-trip to the same string:
|
||||
// https://pkg.go.dev/database/sql#Rows.Scan
|
||||
if v.Format(time.RFC3339Nano) != str {
|
||||
t.Fatalf("did not round-trip: %q", str)
|
||||
}
|
||||
case string:
|
||||
if v != str {
|
||||
t.Fatalf("did not round-trip: %q", str)
|
||||
}
|
||||
|
||||
date, err := time.Parse(time.RFC3339Nano, str)
|
||||
if err == nil && date.Format(time.RFC3339Nano) == str {
|
||||
t.Fatalf("would round-trip: %q", str)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("invalid type %T: %q", v, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
7
error.go
7
error.go
@@ -50,6 +50,11 @@ func (e *Error) Error() string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.Code() == BUSY
|
||||
}
|
||||
|
||||
// SQL returns the SQL starting at the token that triggered a syntax error.
|
||||
func (e *Error) SQL() string {
|
||||
return e.sql
|
||||
@@ -60,12 +65,14 @@ type errorString string
|
||||
func (e errorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
|
||||
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
|
||||
oomErr = errorString("sqlite3: out of memory")
|
||||
rangeErr = errorString("sqlite3: index out of range")
|
||||
noNulErr = errorString("sqlite3: missing NUL terminator")
|
||||
noGlobalErr = errorString("sqlite3: could not find global: ")
|
||||
noFuncErr = errorString("sqlite3: could not find function: ")
|
||||
timeErr = errorString("sqlite3: invalid time value")
|
||||
)
|
||||
|
||||
func assertErr() errorString {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -4,7 +4,7 @@ go 1.19
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v0.1.5
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.8
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.9
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/sys v0.5.0
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1,7 +1,7 @@
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.8 h1:Ir82PWj79WCppH+9ny73eGY2qv+oCnE3VwMY92cBSyI=
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.8/go.mod h1:u8wrFmpdrykiFK0DFPiFm5a4+0RzsdmXYVtijBKqUVo=
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
|
||||
161
mock_test.go
Normal file
161
mock_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
|
||||
func newMemory(size uint32) memory {
|
||||
mem := make(mockMemory, size)
|
||||
return memory{mockModule{&mem}}
|
||||
}
|
||||
|
||||
type mockModule struct {
|
||||
memory api.Memory
|
||||
}
|
||||
|
||||
func (m mockModule) Memory() api.Memory { return m.memory }
|
||||
func (m mockModule) String() string { return "mockModule" }
|
||||
func (m mockModule) Name() string { return "mockModule" }
|
||||
|
||||
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
|
||||
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
|
||||
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
|
||||
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
|
||||
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
|
||||
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
|
||||
func (m mockModule) Close(context.Context) error { return nil }
|
||||
|
||||
type mockMemory []byte
|
||||
|
||||
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
|
||||
|
||||
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
|
||||
|
||||
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
|
||||
if offset >= m.Size() {
|
||||
return 0, false
|
||||
}
|
||||
return m[offset], true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
|
||||
v, ok := m.ReadUint32Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float32frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
|
||||
v, ok := m.ReadUint64Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float64frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
|
||||
if !m.hasSize(offset, byteCount) {
|
||||
return nil, false
|
||||
}
|
||||
return m[offset : offset+byteCount : offset+byteCount], true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
|
||||
if offset >= m.Size() {
|
||||
return false
|
||||
}
|
||||
m[offset] = v
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint16(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint32(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
|
||||
return m.WriteUint32Le(offset, math.Float32bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint64(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
|
||||
return m.WriteUint64Le(offset, math.Float64bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) Write(offset uint32, val []byte) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteString(offset uint32, val string) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
|
||||
prev := (len(*m) + 65535) / 65536
|
||||
*m = append(*m, make([]byte, 65536*delta)...)
|
||||
return uint32(prev), true
|
||||
}
|
||||
|
||||
func (m mockMemory) PageSize() (result uint32) {
|
||||
return uint32(len(m) / 65536)
|
||||
}
|
||||
|
||||
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
|
||||
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
|
||||
}
|
||||
61
stmt.go
61
stmt.go
@@ -2,6 +2,7 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Stmt is a prepared statement object.
|
||||
@@ -15,6 +16,8 @@ type Stmt struct {
|
||||
|
||||
// Close destroys the prepared statement object.
|
||||
//
|
||||
// It is safe to close a nil, zero or closed prepared statement.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/finalize.html
|
||||
func (s *Stmt) Close() error {
|
||||
if s == nil || s.handle == 0 {
|
||||
@@ -219,6 +222,19 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
|
||||
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindZeroBlob(param int, n int64) error {
|
||||
r, err := s.c.api.bindZeroBlob.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(param), uint64(n))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
|
||||
// BindNull binds a NULL to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
@@ -232,6 +248,24 @@ func (s *Stmt) BindNull(param int) error {
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
|
||||
// BindTime binds a [time.Time] to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
s.BindText(param, v)
|
||||
case int64:
|
||||
s.BindInt64(param, v)
|
||||
case float64:
|
||||
s.BindFloat(param, v)
|
||||
default:
|
||||
panic(assertErr())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_count.html
|
||||
@@ -257,7 +291,7 @@ func (s *Stmt) ColumnName(col int) string {
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
panic(oomErr)
|
||||
}
|
||||
return s.c.mem.readString(ptr, _MAX_STRING)
|
||||
}
|
||||
@@ -323,6 +357,31 @@ func (s *Stmt) ColumnFloat(col int) float64 {
|
||||
return math.Float64frombits(r[0])
|
||||
}
|
||||
|
||||
// ColumnTime returns the value of the result column as a [time.Time].
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
|
||||
var v any
|
||||
switch s.ColumnType(col) {
|
||||
case INTEGER:
|
||||
v = s.ColumnInt64(col)
|
||||
case FLOAT:
|
||||
v = s.ColumnFloat(col)
|
||||
case TEXT, BLOB:
|
||||
v = s.ColumnText(col)
|
||||
case NULL:
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(assertErr())
|
||||
}
|
||||
t, err := format.Decode(v)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// ColumnText returns the value of the result column as a string.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
|
||||
180
tests/bradfitz/sql_test.go
Normal file
180
tests/bradfitz/sql_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package bradfitz
|
||||
|
||||
// Adapted from: https://github.com/bradfitz/go-sql-test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
type Tester interface {
|
||||
RunTest(*testing.T, func(params))
|
||||
}
|
||||
|
||||
var (
|
||||
sqlite Tester = sqliteDB{}
|
||||
)
|
||||
|
||||
const TablePrefix = "gosqltest_"
|
||||
|
||||
type sqliteDB struct{}
|
||||
|
||||
type params struct {
|
||||
dbType Tester
|
||||
*testing.T
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
func (t params) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := t.DB.Exec(sql, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
|
||||
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("foo.db open fail: %v", err)
|
||||
}
|
||||
fn(params{sqlite, t, db})
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("foo.db close fail: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
|
||||
|
||||
func testBlobs(t params) {
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar blob)")
|
||||
t.mustExec("insert into "+TablePrefix+"foo (id, bar) values(?,?)", 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) }
|
||||
|
||||
func testManyQueryRow(t params) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
|
||||
t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxQuery_SQLite(t *testing.T) { sqlite.RunTest(t, testTxQuery) }
|
||||
|
||||
func testTxQuery(t params) {
|
||||
tx, err := t.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = t.DB.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query("select name from "+TablePrefix+"foo where id = ?", 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one rows")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) }
|
||||
|
||||
func testPreparedStmt(t params) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
|
||||
sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := t.Prepare("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package compile_empty
|
||||
package compile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package compile_empty
|
||||
package compile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestCompile_empty(t *testing.T) {
|
||||
func TestCompile_missing(t *testing.T) {
|
||||
sqlite3.Path = "sqlite3.wasm"
|
||||
_, err := sqlite3.Open(":memory:")
|
||||
if err == nil {
|
||||
|
||||
14
tests/compile/nil/compile_test.go
Normal file
14
tests/compile/nil/compile_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package compile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestCompile_nil(t *testing.T) {
|
||||
_, err := sqlite3.Open(":memory:")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
}
|
||||
229
tests/conn_test.go
Normal file
229
tests/conn_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestConn_Open_dir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := sqlite3.Open(".")
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
|
||||
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: unable to open database file` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
var conn *sqlite3.Conn
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestConn_Close_BUSY(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`BEGIN`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.BUSY {
|
||||
t.Errorf("got %d, want sqlite3.BUSY", rc)
|
||||
}
|
||||
var terr interface{ Temporary() bool }
|
||||
if !errors.As(err, &terr) || !terr.Temporary() {
|
||||
t.Error("not temporary", err)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_SetInterrupt(t *testing.T) {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
// Interrupt doesn't interrupt this.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
WITH RECURSIVE
|
||||
fibonacci (curr, next)
|
||||
AS (
|
||||
SELECT 0, 1
|
||||
UNION ALL
|
||||
SELECT next, curr + next FROM fibonacci
|
||||
LIMIT 1e6
|
||||
)
|
||||
SELECT min(curr) FROM fibonacci
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
cancel()
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
var serr *sqlite3.Error
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Interrupting sticks.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
// Interrupting can be cleared.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(``)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt != nil {
|
||||
t.Error("want nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_tail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if !strings.Contains(tail, "-- HERE") {
|
||||
t.Errorf("got %q", tail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var serr *sqlite3.Error
|
||||
|
||||
_, _, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.ERROR", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := serr.SQL(); got != `FRM sqlite_schema` {
|
||||
t.Error("got SQL: ", got)
|
||||
}
|
||||
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,8 @@ func TestDB_file(t *testing.T) {
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -32,6 +34,10 @@ func testDB(t *testing.T, name string) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes := db.Changes()
|
||||
if changes != 3 {
|
||||
t.Errorf("got %d want 3", changes)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
@@ -43,18 +49,22 @@ func testDB(t *testing.T, name string) {
|
||||
ids := []int{0, 1, 2}
|
||||
names := []string{"go", "zig", "whatever"}
|
||||
for ; stmt.Step(); row++ {
|
||||
if ids[row] != stmt.ColumnInt(0) {
|
||||
t.Errorf("got %d, want %d", stmt.ColumnInt(0), ids[row])
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if names[row] != stmt.ColumnText(1) {
|
||||
t.Errorf("got %q, want %q", stmt.ColumnText(1), names[row])
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d rows, want %d", row, len(ids))
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDir(t *testing.T) {
|
||||
_, err := sqlite3.Open(".")
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
|
||||
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: unable to open database file` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
103
tests/driver_test.go
Normal file
103
tests/driver_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDriver(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.ExecContext(ctx,
|
||||
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
res, err := conn.ExecContext(ctx,
|
||||
`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changes != 3 {
|
||||
t.Errorf("got %d want 3", changes)
|
||||
}
|
||||
|
||||
stmt, err := conn.PrepareContext(context.Background(),
|
||||
`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 1, 2}
|
||||
names := []string{"go", "zig", "whatever"}
|
||||
for ; rows.Next(); row++ {
|
||||
var id int
|
||||
var name string
|
||||
err := rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -21,7 +22,7 @@ func TestParallel(t *testing.T) {
|
||||
|
||||
func TestMultiProcess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
@@ -44,7 +45,11 @@ func TestMultiProcess(t *testing.T) {
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
if err := cmd.Wait(); err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
var eerr *exec.ExitError
|
||||
if errors.As(err, &eerr) {
|
||||
t.Error(eerr.Stderr)
|
||||
}
|
||||
}
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
@@ -52,7 +57,7 @@ func TestMultiProcess(t *testing.T) {
|
||||
func TestChildProcess(t *testing.T) {
|
||||
name := os.Getenv("TestMultiProcess_dbname")
|
||||
if name == "" || testing.Short() {
|
||||
return
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
@@ -1,14 +1,17 @@
|
||||
package sqlite3
|
||||
package tests
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestStmt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -29,103 +32,80 @@ func TestStmt(t *testing.T) {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
err = stmt.BindBool(1, false)
|
||||
if err != nil {
|
||||
if err := stmt.BindBool(1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if err := stmt.BindBool(1, true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.ClearBindings()
|
||||
if err != nil {
|
||||
if err := stmt.BindInt(1, 2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if err := stmt.BindFloat(1, math.Pi); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindBool(1, true)
|
||||
if err != nil {
|
||||
if err := stmt.BindNull(1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if err := stmt.BindText(1, ""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindInt(1, 2)
|
||||
if err != nil {
|
||||
if err := stmt.BindText(1, "text"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindFloat(1, math.Pi)
|
||||
if err != nil {
|
||||
if err := stmt.BindBlob(1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if err := stmt.BindZeroBlob(1, 4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindNull(1)
|
||||
if err != nil {
|
||||
if err := stmt.ClearBindings(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindText(1, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindText(1, "text")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindBlob(1, []byte("blob"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindBlob(1, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -134,7 +114,7 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The table should have: 0, NULL, 1, 2, π, NULL, "", "text", `blob`, NULL
|
||||
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
|
||||
stmt, _, err = db.Prepare(`SELECT col FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -142,7 +122,7 @@ func TestStmt(t *testing.T) {
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
@@ -163,28 +143,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
@@ -205,7 +164,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
@@ -226,7 +185,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != FLOAT {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
@@ -247,7 +206,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != NULL {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
@@ -268,7 +227,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != TEXT {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
@@ -289,7 +248,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != TEXT {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
@@ -310,7 +269,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != BLOB {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
@@ -331,7 +290,7 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != NULL {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
@@ -351,24 +310,66 @@ func TestStmt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmt_Close(t *testing.T) {
|
||||
var stmt *Stmt
|
||||
var stmt *sqlite3.Stmt
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
func TestStmt_BindName(t *testing.T) {
|
||||
db, err := Open(":memory:")
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -398,3 +399,65 @@ func TestStmt_BindName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmt_Time(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT ?, ?, ?, datetime(), unixepoch(), julianday(), NULL, 'abc'`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
err = stmt.BindTime(1, reference, sqlite3.TimeFormat4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = stmt.BindTime(2, reference, sqlite3.TimeFormatUnixMilli)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = stmt.BindTime(3, reference, sqlite3.TimeFormatJulianDay)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if now := time.Now(); stmt.Step() {
|
||||
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !reference.Equal(got) {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !reference.Equal(got) {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); reference.Sub(got) > time.Millisecond {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
|
||||
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
|
||||
t.Errorf("got %v, want %v", got, now)
|
||||
}
|
||||
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
|
||||
t.Errorf("got %v, want %v", got, now)
|
||||
}
|
||||
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); now.Sub(got) > time.Millisecond {
|
||||
t.Errorf("got %v, want %v", got, now)
|
||||
}
|
||||
|
||||
if got := stmt.ColumnTime(6, sqlite3.TimeFormatAuto); got != (time.Time{}) {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnTime(7, sqlite3.TimeFormatAuto); got != (time.Time{}) {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if stmt.Err() == nil {
|
||||
t.Errorf("want error")
|
||||
}
|
||||
}
|
||||
}
|
||||
317
time.go
Normal file
317
time.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/julianday"
|
||||
)
|
||||
|
||||
// TimeFormat specifies how to encode/decode time values.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
type TimeFormat string
|
||||
|
||||
// TimeFormats recognized by SQLite to encode/decode time values.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
const (
|
||||
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano
|
||||
|
||||
// Text formats
|
||||
TimeFormat1 TimeFormat = "2006-01-02"
|
||||
TimeFormat2 TimeFormat = "2006-01-02 15:04"
|
||||
TimeFormat3 TimeFormat = "2006-01-02 15:04:05"
|
||||
TimeFormat4 TimeFormat = "2006-01-02 15:04:05.000"
|
||||
TimeFormat5 TimeFormat = "2006-01-02T15:04"
|
||||
TimeFormat6 TimeFormat = "2006-01-02T15:04:05"
|
||||
TimeFormat7 TimeFormat = "2006-01-02T15:04:05.000"
|
||||
TimeFormat8 TimeFormat = "15:04"
|
||||
TimeFormat9 TimeFormat = "15:04:05"
|
||||
TimeFormat10 TimeFormat = "15:04:05.000"
|
||||
|
||||
TimeFormat2TZ = TimeFormat2 + "Z07:00"
|
||||
TimeFormat3TZ = TimeFormat3 + "Z07:00"
|
||||
TimeFormat4TZ = TimeFormat4 + "Z07:00"
|
||||
TimeFormat5TZ = TimeFormat5 + "Z07:00"
|
||||
TimeFormat6TZ = TimeFormat6 + "Z07:00"
|
||||
TimeFormat7TZ = TimeFormat7 + "Z07:00"
|
||||
TimeFormat8TZ = TimeFormat8 + "Z07:00"
|
||||
TimeFormat9TZ = TimeFormat9 + "Z07:00"
|
||||
TimeFormat10TZ = TimeFormat10 + "Z07:00"
|
||||
|
||||
// Numeric formats
|
||||
TimeFormatJulianDay TimeFormat = "julianday"
|
||||
TimeFormatUnix TimeFormat = "unixepoch"
|
||||
TimeFormatUnixFrac TimeFormat = "unixepoch_frac"
|
||||
TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format
|
||||
TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format
|
||||
TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format
|
||||
|
||||
// Auto
|
||||
TimeFormatAuto TimeFormat = "auto"
|
||||
)
|
||||
|
||||
// Encode encodes a time value using this format.
|
||||
//
|
||||
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
|
||||
// with nanosecond accuracy, and preserving timezone.
|
||||
//
|
||||
// Formats [TimeFormat1] through [TimeFormat10]
|
||||
// convert time values to UTC before encoding.
|
||||
//
|
||||
// Returns a string for the text formats,
|
||||
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
|
||||
// or an int64 for the other numeric formats.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
func (f TimeFormat) Encode(t time.Time) any {
|
||||
switch f {
|
||||
// Numeric formats
|
||||
case TimeFormatJulianDay:
|
||||
return julianday.Float(t)
|
||||
case TimeFormatUnix:
|
||||
return t.Unix()
|
||||
case TimeFormatUnixFrac:
|
||||
return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
|
||||
case TimeFormatUnixMilli:
|
||||
return t.UnixMilli()
|
||||
case TimeFormatUnixMicro:
|
||||
return t.UnixMicro()
|
||||
case TimeFormatUnixNano:
|
||||
return t.UnixNano()
|
||||
// Special formats
|
||||
case TimeFormatDefault, TimeFormatAuto:
|
||||
f = time.RFC3339Nano
|
||||
// SQLite assumes UTC if unspecified.
|
||||
case
|
||||
TimeFormat1, TimeFormat2,
|
||||
TimeFormat3, TimeFormat4,
|
||||
TimeFormat5, TimeFormat6,
|
||||
TimeFormat7, TimeFormat8,
|
||||
TimeFormat9, TimeFormat10:
|
||||
t = t.UTC()
|
||||
}
|
||||
return t.Format(string(f))
|
||||
}
|
||||
|
||||
// Decode decodes a time value using this format.
|
||||
//
|
||||
// The time value can be a string, an int64, or a float64.
|
||||
//
|
||||
// Formats [TimeFormat8] through [TimeFormat10]
|
||||
// assume a date of 2000-01-01.
|
||||
//
|
||||
// The timezone indicator and fractional seconds are always optional
|
||||
// for formats [TimeFormat2] through [TimeFormat10].
|
||||
//
|
||||
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
|
||||
// The julian day number is safe to use for historical dates,
|
||||
// from 4712BC through 9999AD.
|
||||
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
|
||||
// are safe to use for current events, from 1980 through at least 2260.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
switch f {
|
||||
// Numeric formats
|
||||
case TimeFormatJulianDay:
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
return julianday.Parse(v)
|
||||
case float64:
|
||||
return julianday.FloatTime(v), nil
|
||||
case int64:
|
||||
return julianday.Time(v, 0), nil
|
||||
default:
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnix, TimeFormatUnixFrac:
|
||||
if s, ok := v.(string); ok {
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
v = f
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
sec, frac := math.Modf(v)
|
||||
nsec := math.Floor(frac * 1e9)
|
||||
return time.Unix(int64(sec), int64(nsec)), nil
|
||||
case int64:
|
||||
return time.Unix(v, 0), nil
|
||||
default:
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnixMilli:
|
||||
if s, ok := v.(string); ok {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
v = i
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMilli(int64(math.Floor(v))), nil
|
||||
case int64:
|
||||
return time.UnixMilli(int64(v)), nil
|
||||
default:
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnixMicro:
|
||||
if s, ok := v.(string); ok {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
v = i
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMicro(int64(math.Floor(v))), nil
|
||||
case int64:
|
||||
return time.UnixMicro(int64(v)), nil
|
||||
default:
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnixNano:
|
||||
if s, ok := v.(string); ok {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
v = i
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.Unix(0, int64(math.Floor(v))), nil
|
||||
case int64:
|
||||
return time.Unix(0, int64(v)), nil
|
||||
default:
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
|
||||
// Special formats
|
||||
case TimeFormatAuto:
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
v = i
|
||||
break
|
||||
}
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err == nil {
|
||||
v = f
|
||||
break
|
||||
}
|
||||
|
||||
dates := []TimeFormat{
|
||||
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
|
||||
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
|
||||
TimeFormat1,
|
||||
}
|
||||
for _, f := range dates {
|
||||
t, err := time.Parse(string(f), s)
|
||||
if err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
times := []TimeFormat{
|
||||
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
|
||||
}
|
||||
for _, f := range times {
|
||||
t, err := time.Parse(string(f), s)
|
||||
if err == nil {
|
||||
return t.AddDate(2000, 0, 0), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
if 0 <= v && v < 5373484.5 {
|
||||
return TimeFormatJulianDay.Decode(v)
|
||||
}
|
||||
if v < 253402300800 {
|
||||
return TimeFormatUnixFrac.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000 {
|
||||
return TimeFormatUnixMilli.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000000 {
|
||||
return TimeFormatUnixMicro.Decode(v)
|
||||
}
|
||||
return TimeFormatUnixNano.Decode(v)
|
||||
case int64:
|
||||
if 0 <= v && v < 5373485 {
|
||||
return TimeFormatJulianDay.Decode(v)
|
||||
}
|
||||
if v < 253402300800 {
|
||||
return TimeFormatUnixFrac.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000 {
|
||||
return TimeFormatUnixMilli.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000000 {
|
||||
return TimeFormatUnixMicro.Decode(v)
|
||||
}
|
||||
return TimeFormatUnixNano.Decode(v)
|
||||
default:
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
|
||||
case
|
||||
TimeFormat2, TimeFormat2TZ,
|
||||
TimeFormat3, TimeFormat3TZ,
|
||||
TimeFormat4, TimeFormat4TZ,
|
||||
TimeFormat5, TimeFormat5TZ,
|
||||
TimeFormat6, TimeFormat6TZ,
|
||||
TimeFormat7, TimeFormat7TZ:
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
return f.parseRelaxed(s)
|
||||
|
||||
case
|
||||
TimeFormat8, TimeFormat8TZ,
|
||||
TimeFormat9, TimeFormat9TZ,
|
||||
TimeFormat10, TimeFormat10TZ:
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
t, err := f.parseRelaxed(s)
|
||||
return t.AddDate(2000, 0, 0), err
|
||||
|
||||
default:
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return time.Time{}, timeErr
|
||||
}
|
||||
if f == "" {
|
||||
f = time.RFC3339Nano
|
||||
}
|
||||
return time.Parse(string(f), s)
|
||||
}
|
||||
}
|
||||
|
||||
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
|
||||
fs := string(f)
|
||||
fs = strings.TrimSuffix(fs, "Z07:00")
|
||||
fs = strings.TrimSuffix(fs, ".000")
|
||||
t, err := time.Parse(fs+"Z07:00", s)
|
||||
if err != nil {
|
||||
return time.Parse(fs, s)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
118
time_test.go
Normal file
118
time_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTimeFormat_Encode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
tests := []struct {
|
||||
fmt TimeFormat
|
||||
time time.Time
|
||||
want any
|
||||
}{
|
||||
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
|
||||
{TimeFormatJulianDay, reference, 2456572.849526851851852},
|
||||
{TimeFormatUnix, reference, int64(1381134199)},
|
||||
{TimeFormatUnixFrac, reference, 1381134199.120},
|
||||
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
|
||||
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
|
||||
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
|
||||
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeFormat_Decode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
tests := []struct {
|
||||
fmt TimeFormat
|
||||
val any
|
||||
want time.Time
|
||||
wantDelta time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
|
||||
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
|
||||
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
|
||||
{TimeFormatJulianDay, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
|
||||
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
|
||||
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
|
||||
{TimeFormatUnix, "abc", time.Time{}, 0, true},
|
||||
{TimeFormatUnix, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
|
||||
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
|
||||
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
|
||||
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
|
||||
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
|
||||
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
|
||||
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
|
||||
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
|
||||
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
|
||||
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
|
||||
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
|
||||
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
|
||||
{TimeFormatUnixNano, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
|
||||
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
|
||||
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
|
||||
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
|
||||
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
|
||||
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
|
||||
{TimeFormatAuto, "1381134199", reference, time.Second, false},
|
||||
{TimeFormatAuto, "1381134199120", reference, 0, false},
|
||||
{TimeFormatAuto, "1381134199120000", reference, 0, false},
|
||||
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
|
||||
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
|
||||
{TimeFormatAuto, "abc", time.Time{}, 0, true},
|
||||
{TimeFormatAuto, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
|
||||
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
|
||||
{TimeFormat9, "08:23:19.12", reftime, 0, false},
|
||||
{TimeFormat3, false, time.Time{}, 0, true},
|
||||
{TimeFormat9, false, time.Time{}, 0, true},
|
||||
|
||||
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
|
||||
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
|
||||
{TimeFormatDefault, false, time.Time{}, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
got, err := tt.fmt.Decode(tt.val)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want.Sub(got).Abs() > tt.wantDelta {
|
||||
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
util.go
Normal file
16
util.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package sqlite3
|
||||
|
||||
// Return true if stmt is an empty SQL statement.
|
||||
// This is used as an optimization.
|
||||
// It's OK to always return false here.
|
||||
func emptyStatement(stmt string) bool {
|
||||
for _, b := range []byte(stmt) {
|
||||
switch b {
|
||||
case ' ', '\n', '\r', '\t', '\v', '\f':
|
||||
case ';':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
195
util_test.go
195
util_test.go
@@ -1,161 +1,60 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
func Test_emptyStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
func newMemory(size uint32) memory {
|
||||
mem := make(mockMemory, size)
|
||||
return memory{mockModule{&mem}}
|
||||
}
|
||||
|
||||
type mockModule struct {
|
||||
memory api.Memory
|
||||
}
|
||||
|
||||
func (m mockModule) Memory() api.Memory { return m.memory }
|
||||
func (m mockModule) String() string { return "mockModule" }
|
||||
func (m mockModule) Name() string { return "mockModule" }
|
||||
|
||||
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
|
||||
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
|
||||
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
|
||||
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
|
||||
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
|
||||
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
|
||||
func (m mockModule) Close(context.Context) error { return nil }
|
||||
|
||||
type mockMemory []byte
|
||||
|
||||
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
|
||||
|
||||
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
|
||||
|
||||
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
|
||||
if offset >= m.Size() {
|
||||
return 0, false
|
||||
tests := []struct {
|
||||
name string
|
||||
stmt string
|
||||
want bool
|
||||
}{
|
||||
{"empty", "", true},
|
||||
{"space", " ", true},
|
||||
{"separator", ";\n ", true},
|
||||
{"begin", "BEGIN", false},
|
||||
{"select", "SELECT 1;", false},
|
||||
}
|
||||
return m[offset], true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return 0, false
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := emptyStatement(tt.stmt); got != tt.want {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return 0, false
|
||||
func Fuzz_emptyStatement(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add(";\n ")
|
||||
f.Add("; ;\v")
|
||||
f.Add("BEGIN")
|
||||
f.Add("SELECT 1;")
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
f.Fatal(err)
|
||||
}
|
||||
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
|
||||
v, ok := m.ReadUint32Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float32frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
|
||||
v, ok := m.ReadUint64Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float64frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
|
||||
if !m.hasSize(offset, byteCount) {
|
||||
return nil, false
|
||||
}
|
||||
return m[offset : offset+byteCount : offset+byteCount], true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
|
||||
if offset >= m.Size() {
|
||||
return false
|
||||
}
|
||||
m[offset] = v
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint16(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint32(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
|
||||
return m.WriteUint32Le(offset, math.Float32bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint64(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
|
||||
return m.WriteUint64Le(offset, math.Float64bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) Write(offset uint32, val []byte) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteString(offset uint32, val string) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
|
||||
prev := (len(*m) + 65535) / 65536
|
||||
*m = append(*m, make([]byte, 65536*delta)...)
|
||||
return uint32(prev), true
|
||||
}
|
||||
|
||||
func (m mockMemory) PageSize() (result uint32) {
|
||||
return uint32(len(m) / 65536)
|
||||
}
|
||||
|
||||
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
|
||||
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
|
||||
f.Fuzz(func(t *testing.T, sql string) {
|
||||
// If empty, SQLite parses it as empty.
|
||||
if emptyStatement(sql) {
|
||||
stmt, tail, err := db.Prepare(sql)
|
||||
if err != nil {
|
||||
t.Errorf("%q, %v", sql, err)
|
||||
}
|
||||
if stmt != nil {
|
||||
t.Errorf("%q, %v", sql, stmt)
|
||||
}
|
||||
if tail != "" {
|
||||
t.Errorf("%q", sql)
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -242,8 +242,9 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
|
||||
// Release the file lock only when all connections have released the lock.
|
||||
ptr.SetLock(_NO_LOCK)
|
||||
if fLock.shared--; fLock.shared == 0 {
|
||||
rc := fLock.Release()
|
||||
fLock.state = _NO_LOCK
|
||||
return uint32(fLock.Release())
|
||||
return uint32(rc)
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package sqlite3
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
@@ -16,16 +17,15 @@ func Test_vfsLock(t *testing.T) {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
|
||||
// Create a temporary file.
|
||||
file1, err := os.CreateTemp("", "sqlite3-")
|
||||
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file1.Close()
|
||||
|
||||
name := file1.Name()
|
||||
defer os.RemoveAll(name)
|
||||
|
||||
// Open the temporary file again.
|
||||
file2, err := os.OpenFile(name, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
|
||||
29
vfs_test.go
29
vfs_test.go
@@ -7,6 +7,7 @@ import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -136,12 +137,12 @@ func Test_vfsFullPathname(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_vfsDelete(t *testing.T) {
|
||||
file, err := os.CreateTemp("", "sqlite3-")
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
|
||||
file, err := os.Create(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
name := file.Name()
|
||||
defer os.RemoveAll(name)
|
||||
file.Close()
|
||||
|
||||
mem := newMemory(128 + _MAX_PATHNAME)
|
||||
@@ -163,8 +164,19 @@ func Test_vfsDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_vfsAccess(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(t.TempDir(), "test.db")
|
||||
if f, err := os.Create(file); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
f.Close()
|
||||
}
|
||||
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mem := newMemory(128 + _MAX_PATHNAME)
|
||||
mem.writeString(8, t.TempDir())
|
||||
mem.writeString(8, dir)
|
||||
|
||||
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
|
||||
if rc != _OK {
|
||||
@@ -181,6 +193,15 @@ func Test_vfsAccess(t *testing.T) {
|
||||
if got := mem.readUint32(4); got != 1 {
|
||||
t.Error("can't access directory")
|
||||
}
|
||||
|
||||
mem.writeString(8, file)
|
||||
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
if got := mem.readUint32(4); got != 0 {
|
||||
t.Error("can access file")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_vfsFile(t *testing.T) {
|
||||
|
||||
19
vfs_unix.go
19
vfs_unix.go
@@ -33,16 +33,17 @@ func (l *vfsFileLocker) GetExclusive() xErrorCode {
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) Downgrade() xErrorCode {
|
||||
// Downgrade to a SHARED lock.
|
||||
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
|
||||
// In theory, the downgrade to a SHARED cannot fail because another
|
||||
// process is holding an incompatible lock. If it does, this
|
||||
// indicates that the other process is not following the locking
|
||||
// protocol. If this happens, return IOERR_RDLOCK. Returning
|
||||
// BUSY would confuse the upper layer.
|
||||
return IOERR_RDLOCK
|
||||
if l.state >= _EXCLUSIVE_LOCK {
|
||||
// Downgrade to a SHARED lock.
|
||||
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
|
||||
// In theory, the downgrade to a SHARED cannot fail because another
|
||||
// process is holding an incompatible lock. If it does, this
|
||||
// indicates that the other process is not following the locking
|
||||
// protocol. If this happens, return IOERR_RDLOCK. Returning
|
||||
// BUSY would confuse the upper layer.
|
||||
return IOERR_RDLOCK
|
||||
}
|
||||
}
|
||||
|
||||
// Release the PENDING and RESERVED locks.
|
||||
return l.unlock(_PENDING_BYTE, 2)
|
||||
}
|
||||
|
||||
@@ -39,27 +39,39 @@ func (l *vfsFileLocker) GetExclusive() xErrorCode {
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) Downgrade() xErrorCode {
|
||||
// Release the SHARED lock.
|
||||
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
if l.state >= _EXCLUSIVE_LOCK {
|
||||
// Release the SHARED lock.
|
||||
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
|
||||
// Reacquire the SHARED lock.
|
||||
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
|
||||
// This should never happen.
|
||||
// We should always be able to reacquire the read lock.
|
||||
return IOERR_RDLOCK
|
||||
// Reacquire the SHARED lock.
|
||||
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
|
||||
// This should never happen.
|
||||
// We should always be able to reacquire the read lock.
|
||||
return IOERR_RDLOCK
|
||||
}
|
||||
}
|
||||
|
||||
// Release the PENDING and RESERVED locks.
|
||||
l.unlock(_RESERVED_BYTE, 1)
|
||||
l.unlock(_PENDING_BYTE, 1)
|
||||
if l.state >= _RESERVED_LOCK {
|
||||
l.unlock(_RESERVED_BYTE, 1)
|
||||
}
|
||||
if l.state >= _PENDING_LOCK {
|
||||
l.unlock(_PENDING_BYTE, 1)
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func (l *vfsFileLocker) Release() xErrorCode {
|
||||
// Release all locks.
|
||||
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
l.unlock(_RESERVED_BYTE, 1)
|
||||
l.unlock(_PENDING_BYTE, 1)
|
||||
if l.state >= _RESERVED_LOCK {
|
||||
l.unlock(_RESERVED_BYTE, 1)
|
||||
}
|
||||
if l.state >= _SHARED_LOCK {
|
||||
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
|
||||
}
|
||||
if l.state >= _PENDING_LOCK {
|
||||
l.unlock(_PENDING_BYTE, 1)
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user