Compare commits

..

58 Commits

Author SHA1 Message Date
Nuno Cruces
dbc400eb15 Refactor native code. 2023-03-01 13:12:32 +00:00
Nuno Cruces
35265271aa Rename. 2023-03-01 11:18:25 +00:00
Nuno Cruces
c7165a2e56 Documentation. 2023-03-01 10:34:39 +00:00
Nuno Cruces
e64bffa520 Pragmas. 2023-02-28 16:03:31 +00:00
Nuno Cruces
54046b6adc Documentation. 2023-02-28 16:02:13 +00:00
Nuno Cruces
1b3823483f Incremental blobs. 2023-02-27 13:45:32 +00:00
Nuno Cruces
ce6d0627b2 Tests. 2023-02-27 12:07:48 +00:00
Nuno Cruces
dd30215702 Incremental blobs. 2023-02-27 04:08:55 +00:00
Nuno Cruces
21aff4c9f5 Towards incremental blobs. 2023-02-27 03:20:23 +00:00
Nuno Cruces
b30f127547 WAL mode, extensions. 2023-02-26 04:49:10 +00:00
Nuno Cruces
6509e5deb2 Transactions. 2023-02-26 03:22:08 +00:00
Nuno Cruces
125b8053f8 Fix readonly transactions. 2023-02-25 15:34:24 +00:00
Nuno Cruces
1e4a246d2f Error handling. 2023-02-25 15:11:07 +00:00
Nuno Cruces
e6cd0aaf87 MustPrepare. 2023-02-25 01:29:46 +00:00
Nuno Cruces
c1472a48b0 Tests. 2023-02-25 00:50:03 +00:00
Nuno Cruces
a69ab1ebe3 Fix data race. 2023-02-24 15:19:57 +00:00
Nuno Cruces
1190c21684 Refactor. 2023-02-24 15:06:19 +00:00
Nuno Cruces
8c28c3a6f4 Interrupt API. 2023-02-24 14:56:49 +00:00
Nuno Cruces
0146496036 Nested transactions. 2023-02-24 14:31:41 +00:00
Nuno Cruces
fcd33d2f0f Time improvements. 2023-02-24 11:09:30 +00:00
Nuno Cruces
627df5db0f No sandbox. 2023-02-23 14:16:37 +00:00
Nuno Cruces
1ed62d300d Require OFD locks. 2023-02-23 13:29:51 +00:00
Nuno Cruces
5b2451c3ad Default sector size. 2023-02-23 03:22:39 +00:00
Nuno Cruces
d52e0371eb Only reuse main db files. 2023-02-23 02:22:57 +00:00
Nuno Cruces
75f2644b0e SQLite 3.41.0. 2023-02-22 20:08:50 +00:00
Nuno Cruces
71ae26e5c9 Documentation. 2023-02-22 17:51:30 +00:00
Nuno Cruces
e91758c6a4 Zero blobs, tests, documentation 2023-02-22 14:19:56 +00:00
Nuno Cruces
b749b32a62 Unlock tweaks, tests. 2023-02-21 12:56:39 +00:00
Nuno Cruces
3b4df71a94 Time handling. 2023-02-21 04:45:25 +00:00
Nuno Cruces
df687a1c54 Tests. 2023-02-20 14:43:19 +00:00
Edoardo Vacchi
2f5b9837e1 deps: updates wazero to 1.0.0-pre.9
This updates [wazero](https://wazero.io/) to [1.0.0-pre.9][1]. Notably:

* This release includes our last breaking changes before 1.0.0 final:
  * Requires at least Go 1.8
  * Renames `Runtime.InstantiateModuleFromBinary` to `Runtime.Instantiate`
* This release also integrates Go context to limit execution time.
  More details on the [Release Notes][1]
* We are now passing third-party integration test suites: wasi-testsuite,
  TinyGo's, Zig's.

[1]: https://github.com/tetratelabs/wazero/releases/tag/v1.0.0-pre.9

Signed-off-by: Edoardo Vacchi <evacchi@users.noreply.github.com>
2023-02-20 13:32:52 +00:00
Nuno Cruces
c351400be7 Tests. 2023-02-20 13:30:01 +00:00
Nuno Cruces
231d3a0438 Read-only transactions, locking. 2023-02-19 16:16:13 +00:00
Nuno Cruces
2f25e4eedb Bug fixes, optimizations, fuzz testing. 2023-02-19 12:44:26 +00:00
Nuno Cruces
ad27d5d840 Support pragmas, integration test. 2023-02-18 13:15:01 +00:00
Nuno Cruces
ec5bd236f8 Documentation. 2023-02-18 03:46:52 +00:00
Nuno Cruces
a51cdb04e6 Exec fast path. 2023-02-18 02:57:47 +00:00
Nuno Cruces
f50d5df3d0 Context cancellation. 2023-02-18 02:16:11 +00:00
Nuno Cruces
4ac2ccf473 Named parameters. 2023-02-18 00:47:56 +00:00
Nuno Cruces
5f7a72a553 Connection reuse. 2023-02-17 16:36:47 +00:00
Nuno Cruces
643b004727 Reuse byte slices. 2023-02-17 12:30:07 +00:00
Nuno Cruces
72e0415184 Time handling. 2023-02-17 10:40:43 +00:00
Nuno Cruces
28cb558d10 Minimal database/sql driver. 2023-02-17 02:21:07 +00:00
Nuno Cruces
23806b0db1 More tests. 2023-02-16 13:58:53 +00:00
Nuno Cruces
6a80499823 Panic consistently. 2023-02-16 13:52:05 +00:00
Nuno Cruces
110f36bdf9 Fix flakiness. 2023-02-16 13:37:29 +00:00
Nuno Cruces
f85426022d Test data races. 2023-02-15 16:24:34 +00:00
Nuno Cruces
78fd0cbee5 Towards database/sql. 2023-02-15 16:15:14 +00:00
Nuno Cruces
0d59065719 Lock errors. 2023-02-14 11:38:05 +00:00
Nuno Cruces
6110e2d6e2 Memory arenas. 2023-02-14 11:34:24 +00:00
Nuno Cruces
275b8c38a2 Documentation. 2023-02-14 11:33:41 +00:00
Nuno Cruces
fd1244c471 Support utf16 DBs. 2023-02-14 01:21:12 +00:00
Nuno Cruces
f11d294825 Check integrity. 2023-02-13 16:00:27 +00:00
Nuno Cruces
22b702fcda Synchronize IPC test. 2023-02-13 15:23:11 +00:00
Nuno Cruces
831817a737 Test IPC. 2023-02-13 15:01:36 +00:00
Nuno Cruces
7329d9f2fb Avoid writer starvation. 2023-02-13 13:53:32 +00:00
Nuno Cruces
3aad1d5d79 Towards xFileControl. 2023-02-13 13:52:52 +00:00
Nuno Cruces
f72c599d2d illumos OFD locks. 2023-02-13 13:51:35 +00:00
62 changed files with 5417 additions and 1253 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1 @@
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6

View File

@@ -28,7 +28,13 @@ jobs:
- name: Test
run: go test -v ./...
- if: matrix.os == 'ubuntu-latest'
name: Update coverage report
- name: Test data races
run: go test -v -race ./...
if: matrix.os == 'ubuntu-latest'
- name: Update coverage report
uses: ncruces/go-coverage-report@main
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'
continue-on-error: true

View File

@@ -4,15 +4,73 @@
[![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html)
⚠️ CAUTION ⚠️
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
This is still very much a WIP.\
DO NOT USE this with data you care about.
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
### Caveats
#### Write-Ahead Logging
Because WASM does not support shared memory,
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
To work around this limitation, SQLite is compiled with
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
making `EXCLUSIVE` the default locking mode.
For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode,
and WAL databases are not supported.
#### Open File Description Locks
On Unix, this module uses [OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
POSIX advisory locks, which SQLite uses, are [broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
OFD locks are fully compatible with process-associated POSIX advisory locks,
and are supported on Linux, macOS and illumos.
As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.org/uri.html).
### Roadmap
Roadmap:
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] come up with a simple, nice API, enough for simple queries
- [x] file locking, compatible with SQLite on Windows/Unix
- [x] design a nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on macOS/Linux/Windows
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [ ] online backup
- [ ] snapshots
- [ ] session extension
- [ ] resumable bulk update
- [ ] shared cache mode
- [ ] unlock-notify
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] in-memory VFS, wrapping a [`bytes.Buffer`](https://pkg.go.dev/bytes#Buffer)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom VFS API
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

47
api.go
View File

@@ -1,3 +1,4 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
@@ -16,7 +17,7 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
return f
}
getPtr := func(name string) uint32 {
getVal := func(name string) uint32 {
global := module.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
@@ -29,9 +30,9 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
ctx: ctx,
mem: memory{module},
api: sqliteAPI{
malloc: getFun("malloc"),
free: getFun("free"),
destructor: uint64(getPtr("malloc_destructor")),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
@@ -44,18 +45,33 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindNull: getFun("sqlite3_bind_null"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
columnType: getFun("sqlite3_column_type"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
interrupt: getVal("sqlite3_interrupt_offset"),
},
}
if err != nil {
@@ -65,8 +81,8 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
}
type sqliteAPI struct {
malloc api.Function
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
@@ -80,16 +96,31 @@ type sqliteAPI struct {
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindNull api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
columnType api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
interrupt uint32
}

158
blob.go Normal file
View File

@@ -0,0 +1,158 @@
package sqlite3
import "io"
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64
// Blob is a handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
//
// https://www.sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
handle uint32
bytes int64
offset int64
}
var _ io.ReadWriteSeeker = &Blob{}
// OpenBlob opens a BLOB for incremental I/O.
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
tablePtr := c.arena.string(table)
columnPtr := c.arena.string(column)
var flags uint64
if write {
flags = 1
}
r := c.call(c.api.blobOpen, uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
if err := c.error(r[0]); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = c.mem.readUint32(blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
return &blob, nil
}
// Close closes a BLOB handle.
//
// It is safe to close a nil, zero or closed Blob.
//
// https://www.sqlite.org/c3ref/blob_close.html
func (b *Blob) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
b.handle = 0
return b.c.error(r[0])
}
// Size returns the size of the BLOB in bytes.
//
// https://www.sqlite.org/c3ref/blob_bytes.html
func (b *Blob) Size() int64 {
return b.bytes
}
// Read implements the [io.Reader] interface.
//
// https://www.sqlite.org/c3ref/blob_read.html
func (b *Blob) Read(p []byte) (n int, err error) {
if b.offset >= b.bytes {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
mem := b.c.mem.view(ptr, uint64(want))
copy(p, mem)
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
return int(want), err
}
// Write implements the [io.Writer] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
offset := b.offset
if offset > b.bytes {
offset = b.bytes
}
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
b.offset += int64(len(p))
return len(p), nil
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, whenceErr
case io.SeekStart:
break
case io.SeekCurrent:
offset += b.offset
case io.SeekEnd:
offset += b.bytes
}
if offset < 0 {
return 0, offsetErr
}
b.offset = offset
return offset, nil
}
// Reopen moves a BLOB handle to a new row of the same database table.
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
b.offset = 0
return b.c.error(r[0])
}

View File

@@ -2,7 +2,9 @@ package sqlite3
import (
"context"
"crypto/rand"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
@@ -11,9 +13,14 @@ import (
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite.
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // Binary to load.
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
@@ -24,28 +31,26 @@ type sqlite3Runtime struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
ctx context.Context
err error
}
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.ctx = ctx
s.once.Do(s.compileModule)
s.once.Do(func() { s.compileModule(ctx) })
if s.err != nil {
return nil, s.err
}
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10)).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule() {
s.runtime = wazero.NewRuntime(s.ctx)
s.err = vfsInstantiate(s.ctx, s.runtime)
if s.err != nil {
return
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
bin := Binary
if bin == nil && Path != "" {
@@ -54,6 +59,10 @@ func (s *sqlite3Runtime) compileModule() {
return
}
}
if bin == nil {
s.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(s.ctx, bin)
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

344
conn.go
View File

@@ -2,10 +2,19 @@ package sqlite3
import (
"context"
"database/sql/driver"
"fmt"
"math"
"net/url"
"strings"
"sync/atomic"
"unsafe"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
// A Conn is not safe for concurrent use by multiple goroutines.
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
@@ -13,15 +22,24 @@ type Conn struct {
api sqliteAPI
mem memory
handle uint32
arena arena
interrupt context.Context
waiter chan struct{}
pending *Stmt
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (conn *Conn, err error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE)
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
@@ -39,21 +57,33 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
if err != nil {
return nil, err
}
c.arena = c.newArena(1024)
namePtr := c.newString(filename)
connPtr := c.new(ptrlen)
defer c.free(namePtr)
defer c.free(connPtr)
defer c.arena.reset()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
if err != nil {
return nil, err
}
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
return nil, err
}
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
query, _ := url.ParseQuery(after)
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
}
}
if err := c.Exec(pragmas.String()); err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
}
return c, nil
}
@@ -63,17 +93,17 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
// open blob handles, and/or unfinished backup objects,
// Close will leave the database connection open and return [BUSY].
//
// It is safe to close a nil, zero or closed Conn.
//
// https://www.sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
if c == nil {
if c == nil || c.handle == 0 {
return nil
}
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
return err
}
c.SetInterrupt(context.Background())
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
return err
}
@@ -87,16 +117,30 @@ func (c *Conn) Close() error {
//
// https://www.sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
sqlPtr := c.newString(sql)
defer c.free(sqlPtr)
c.checkInterrupt()
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
return err
}
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r[0])
}
// MustPrepare calls [Conn.Prepare] and panics on error,
// a nil Stmt, or a non-empty tail.
func (c *Conn) MustPrepare(sql string) *Stmt {
s, tail, err := c.PrepareFlags(sql, 0)
if err != nil {
panic(err)
}
if s == nil {
panic(emptyErr)
}
if !emptyStatement(tail) {
panic(tailErr)
}
return s
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
@@ -109,19 +153,18 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
sqlPtr := c.newString(sql)
stmtPtr := c.new(ptrlen)
tailPtr := c.new(ptrlen)
defer c.free(sqlPtr)
defer c.free(stmtPtr)
defer c.free(tailPtr)
if emptyStatement(sql) {
return nil, "", nil
}
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
defer c.arena.reset()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r := c.call(c.api.prepare, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
return nil, "", err
}
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
@@ -137,6 +180,124 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://www.sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call(c.api.autocommit, uint64(c.handle))
return r[0] != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
r := c.call(c.api.lastRowid, uint64(c.handle))
return int64(r[0])
}
// Changes returns the number of rows modified, inserted or deleted
// by the most recently completed INSERT, UPDATE or DELETE statement
// on the database connection.
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
r := c.call(c.api.changes, uint64(c.handle))
return int64(r[0])
}
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until the context is reset by another call to SetInterrupt.
//
// To associate a timeout with a connection:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx)
// defer cancel()
//
// SetInterrupt returns the old context assigned to the connection.
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
} else {
c.pending.Reset()
}
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
buf := c.mem.view(c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
buf := c.mem.view(c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) []string {
stmt := c.MustPrepare(`PRAGMA ` + str)
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil
@@ -150,49 +311,52 @@ func (c *Conn) error(rc uint64, sql ...string) error {
var r []uint64
// sqlite3_errmsg is guaranteed to never change the value of the error code.
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), 512)
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
// sqlite3_error_offset is guaranteed to never change the value of the error code.
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), 512)
}
if err.msg == err.str {
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (c *Conn) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(c.ctx, params...)
if err != nil {
panic(err)
}
return r
}
func (c *Conn) free(ptr uint32) {
if ptr == 0 {
return
}
_, err := c.api.free.Call(c.ctx, uint64(ptr))
if err != nil {
panic(err)
}
c.call(c.api.free, uint64(ptr))
}
func (c *Conn) new(len uint32) uint32 {
r, err := c.api.malloc.Call(c.ctx, uint64(len))
if err != nil {
panic(err)
func (c *Conn) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
}
r := c.call(c.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && len != 0 {
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
@@ -202,19 +366,79 @@ func (c *Conn) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
siz := uint32(len(b))
ptr := c.new(siz)
buf := c.mem.view(ptr, siz)
copy(buf, b)
ptr := c.new(uint64(len(b)))
c.mem.writeBytes(ptr, b)
return ptr
}
func (c *Conn) newString(s string) uint32 {
siz := uint32(len(s) + 1)
ptr := c.new(siz)
buf := c.mem.view(ptr, siz)
buf[len(s)] = 0
copy(buf, s)
ptr := c.new(uint64(len(s) + 1))
c.mem.writeString(ptr, s)
return ptr
}
func (c *Conn) newArena(size uint64) arena {
return arena{
c: c,
base: c.new(size),
size: uint32(size),
}
}
type arena struct {
c *Conn
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
if a.c == nil {
return
}
a.reset()
a.c.free(a.base)
a.c = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.c.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.c.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.c.mem.writeString(ptr, s)
return ptr
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints] and [incremental BLOB I/O].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
Savepoint() (release func(*error))
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -2,105 +2,13 @@ package sqlite3
import (
"bytes"
"errors"
"math"
"testing"
)
func TestConn_Close(t *testing.T) {
var conn *Conn
conn.Close()
}
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
func TestConn_Close_BUSY(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare("BEGIN")
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
var serr *Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
}
}
func TestConn_Prepare_Empty(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare("")
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_Invalid(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *Error
_, _, err = db.Prepare("SELECT")
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
}
_, _, err = db.Prepare("SELECT * FRM sqlite_schema")
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
}
}
func TestConn_new(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -108,11 +16,79 @@ func TestConn_new(t *testing.T) {
defer db.Close()
defer func() { _ = recover() }()
db.new(math.MaxUint32)
db.error(uint64(NOMEM))
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
db.call(db.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
testOOM := func(size uint64) {
defer func() { _ = recover() }()
db.new(size)
t.Error("want panic")
}
testOOM(math.MaxUint32)
testOOM(_MAX_ALLOCATION_SIZE)
}
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
arena := db.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
ptr := arena.string(title)
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
const body = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
ptr = arena.string(body)
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
}
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -131,12 +107,14 @@ func TestConn_newBytes(t *testing.T) {
}
want := buf
if got := db.mem.view(ptr, uint32(len(want))); !bytes.Equal(got, want) {
if got := db.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -155,12 +133,14 @@ func TestConn_newString(t *testing.T) {
}
want := str + "\000"
if got := db.mem.view(ptr, uint32(len(want))); string(got) != want {
if got := db.mem.view(ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -200,6 +180,8 @@ func TestConn_getString(t *testing.T) {
}
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)

View File

@@ -9,11 +9,17 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_MAX_ALLOCATION_SIZE = 0x7ffffeff
ptrlen = 4
)
// ErrorCode is a result code that [Error.Code] might return.
//
// https://www.sqlite.org/rescode.html
type ErrorCode uint8
const (
@@ -47,6 +53,9 @@ const (
WARNING ErrorCode = 28 /* Warnings from sqlite3_log() */
)
// ExtendedErrorCode is a result code that [Error.ExtendedCode] might return.
//
// https://www.sqlite.org/rescode.html
type (
ExtendedErrorCode uint16
xErrorCode = ExtendedErrorCode
@@ -128,6 +137,9 @@ const (
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
// OpenFlag is a flag for a file open operation.
//
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
@@ -155,14 +167,17 @@ const (
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type AccessFlag uint32
type _AccessFlag uint32
const (
ACCESS_EXISTS AccessFlag = 0
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
ACCESS_READ AccessFlag = 2 /* Unused */
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_prepare_normalize.html
type PrepareFlag uint32
const (
@@ -171,6 +186,9 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// Datatype is a fundamental datatype of SQLite.
//
// https://www.sqlite.org/c3ref/c_blob.html
type Datatype uint32
const (
@@ -181,6 +199,7 @@ const (
NULL Datatype = 5
)
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
switch t {

View File

@@ -3,6 +3,8 @@ package sqlite3
import "testing"
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
want string

379
driver/driver.go Normal file
View File

@@ -0,0 +1,379 @@
// Package driver provides a database/sql driver for SQLite.
//
// Importing package driver registers a [database/sql] driver named "sqlite3".
// You may also need to import package embed.
//
// import _ "github.com/ncruces/go-sqlite3/driver"
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// The data source name for "sqlite3" databases can be a filename or a "file:" [URI].
//
// The [TRANSACTION] mode can be specified using "_txlock":
//
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// If no PRAGMAs are specifed, a busy timeout of 1 minute
// and normal locking mode are used.
//
// [URI]: https://www.sqlite.org/uri.html
// [PRAGMA]: https://www.sqlite.org/pragma.html
// [TRANSACTION]: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
package driver
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"net/url"
"strings"
"time"
"github.com/ncruces/go-sqlite3"
)
func init() {
sql.Register("sqlite3", sqlite{})
}
type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
if err != nil {
return nil, err
}
var txBegin string
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
pragmas = query["_pragma"]
}
}
if len(pragmas) == 0 {
err := c.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
if err != nil {
c.Close()
return nil, err
}
}
return conn{
conn: c,
txBegin: txBegin,
}, nil
}
type conn struct {
conn *sqlite3.Conn
txBegin string
txCommit string
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ sqlite3.DriverConn = conn{}
)
func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
txBegin := c.txBegin
c.txCommit = `COMMIT`
if opts.ReadOnly {
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + c.conn.Pragma("query_only")[0]
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
}
err := c.conn.Exec(txBegin)
if err != nil {
return nil, err
}
return c, nil
}
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, tail, err := c.conn.Prepare(query)
if err != nil {
return nil, err
}
if tail != "" {
// Check if the tail contains any SQL.
st, _, err := c.conn.Prepare(tail)
if err != nil {
s.Close()
return nil, err
}
if st != nil {
s.Close()
st.Close()
return nil, tailErr
}
}
return stmt{s, c.conn}, nil
}
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return c.Prepare(query)
}
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
old := c.conn.SetInterrupt(ctx)
defer c.conn.SetInterrupt(old)
err := c.conn.Exec(query)
if err != nil {
return nil, err
}
return result{
c.conn.LastInsertRowID(),
c.conn.Changes(),
}, nil
}
func (c conn) Savepoint() (release func(*error)) {
return c.conn.Savepoint()
}
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
return c.conn.OpenBlob(db, table, column, row, write)
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.NamedValueChecker = stmt{}
)
func (s stmt) Close() error {
return s.stmt.Close()
}
func (s stmt) NumInput() int {
n := s.stmt.BindCount()
for i := 1; i <= n; i++ {
if s.stmt.BindName(i) != "" {
return -1
}
}
return n
}
// Deprecated: use ExecContext instead.
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), namedValues(args))
}
// Deprecated: use QueryContext instead.
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), namedValues(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
// Use QueryContext to setup bindings.
// No need to close rows: that simply resets the statement, exec does the same.
_, err := s.QueryContext(ctx, args)
if err != nil {
return nil, err
}
err = s.stmt.Exec()
if err != nil {
return nil, err
}
return result{
int64(s.conn.LastInsertRowID()),
int64(s.conn.Changes()),
}, nil
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.stmt.ClearBindings()
if err != nil {
return nil, err
}
var ids [3]int
for _, arg := range args {
ids := ids[:0]
if arg.Name == "" {
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range []string{":", "@", "$"} {
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
}
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
case int:
err = s.stmt.BindInt(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
case float64:
err = s.stmt.BindFloat(id, a)
case string:
err = s.stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
}
}
if err != nil {
return nil, err
}
}
return rows{ctx, s.stmt, s.conn}, nil
}
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
return nil
default:
return driver.ErrSkip
}
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {
return r.lastInsertId, nil
}
func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type rows struct {
ctx context.Context
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
func (r rows) Close() error {
return r.stmt.Reset()
}
func (r rows) Columns() []string {
count := r.stmt.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.stmt.ColumnName(i)
}
return columns
}
func (r rows) Next(dest []driver.Value) error {
old := r.conn.SetInterrupt(r.ctx)
defer r.conn.SetInterrupt(old)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
return err
}
return io.EOF
}
for i := range dest {
switch r.stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeTime(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
} else {
dest[i] = nil
}
default:
panic(assertErr)
}
}
return r.stmt.Err()
}

309
driver/driver_test.go Normal file
View File

@@ -0,0 +1,309 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
"bytes"
"context"
"database/sql"
"errors"
"math"
"path/filepath"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func Test_Open_dir(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ".")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
if err == nil {
t.Fatal("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func Test_Open_pragma(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var timeout int
err = db.QueryRow(`PRAGMA busy_timeout`).Scan(&timeout)
if err != nil {
t.Fatal(err)
}
if timeout != 1000 {
t.Errorf("got %v, want 1000", timeout)
}
}
func Test_Open_pragma_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` {
t.Error("got message:", got)
}
}
func Test_Open_txLock(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx1, err := db.Begin()
if err != nil {
t.Fatal(err)
}
_, err = db.Begin()
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.BUSY) {
t.Errorf("got %v, want sqlite3.BUSY", err)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
err = tx1.Commit()
if err != nil {
t.Fatal(err)
}
}
func Test_Open_txLock_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
if err == nil {
t.Fatal("want error")
}
if got := err.Error(); got != `sqlite3: invalid _txlock: xclusive` {
t.Error("got message:", got)
}
}
func Test_BeginTx(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(isolationErr) {
t.Error("want isolationErr")
}
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
_, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.READONLY) {
t.Errorf("got %v, want sqlite3.READONLY", err)
}
err = tx2.Commit()
if err != nil {
t.Fatal(err)
}
err = tx1.Commit()
if err != nil {
t.Fatal(err)
}
}
func Test_Prepare(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
var serr *sqlite3.Error
_, err = db.Prepare(`SELECT`)
if err == nil {
t.Error("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; SELECT`)
if err == nil {
t.Error("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
if err.Error() != string(tailErr) {
t.Error("want tailErr")
}
}
func Test_QueryRow_named(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
date := time.Now()
row := stmt.QueryRow(true, sql.Named("AAA", math.Pi), nil /*3*/, nil /*4*/, date /*5*/)
var first bool
var fifth time.Time
var colon, at, dollar float32
err = row.Scan(&first, &fifth, &colon, &at, &dollar)
if err != nil {
t.Fatal(err)
}
if first != true {
t.Errorf("want true, got %v", first)
}
if colon != math.Pi {
t.Errorf("want π, got %v", colon)
}
if at != math.Pi {
t.Errorf("want π, got %v", at)
}
if dollar != math.Pi {
t.Errorf("want π, got %v", dollar)
}
if !fifth.Equal(date) {
t.Errorf("want %v, got %v", date, fifth)
}
}
func Test_QueryRow_blob_null(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
rows, err := db.Query(`
SELECT NULL UNION ALL
SELECT x'cafe' UNION ALL
SELECT x'babe' UNION ALL
SELECT NULL
`)
if err != nil {
t.Fatal(err)
}
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
for i := 0; rows.Next(); i++ {
var buf []byte
err = rows.Scan(&buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, want[i]) {
t.Errorf("got %q, want %q", buf, want[i])
}
}
}

11
driver/error.go Normal file
View File

@@ -0,0 +1,11 @@
package driver
type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
isolationErr = errorString("sqlite3: unsupported isolation level")
)

149
driver/example_test.go Normal file
View File

@@ -0,0 +1,149 @@
package driver_test
// Adapted from: https://go.dev/doc/tutorial/database-access
import (
"database/sql"
"fmt"
"log"
"os"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
type Album struct {
ID int64
Title string
Artist string
Price float32
}
func Example() {
// Get a database handle.
var err error
db, err = sql.Open("sqlite3", "./recordings.db")
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
// Create a table with some data in it.
err = albumsSetup()
if err != nil {
log.Fatal(err)
}
albums, err := albumsByArtist("John Coltrane")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Albums found: %v\n", albums)
// Hard-code ID 2 here to test the query.
alb, err := albumByID(2)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Album found: %v\n", alb)
albID, err := addAlbum(Album{
Title: "The Modern Sound of Betty Carter",
Artist: "Betty Carter",
Price: 49.99,
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("ID of added album: %v\n", albID)
// Output:
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
// Album found: {2 Giant Steps John Coltrane 63.99}
// ID of added album: 5
}
func albumsSetup() error {
_, err := db.Exec(`
DROP TABLE IF EXISTS album;
CREATE TABLE album (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title VARCHAR(128) NOT NULL,
artist VARCHAR(255) NOT NULL,
price DECIMAL(5,2) NOT NULL
);
`)
if err != nil {
return err
}
_, err = db.Exec(`
INSERT INTO album
(title, artist, price)
VALUES
('Blue Train', 'John Coltrane', 56.99),
('Giant Steps', 'John Coltrane', 63.99),
('Jeru', 'Gerry Mulligan', 17.99),
('Sarah Vaughan', 'Sarah Vaughan', 34.98)
`)
if err != nil {
return err
}
return nil
}
// albumsByArtist queries for albums that have the specified artist name.
func albumsByArtist(name string) ([]Album, error) {
// An albums slice to hold data from returned rows.
var albums []Album
rows, err := db.Query("SELECT * FROM album WHERE artist = ?", name)
if err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
}
defer rows.Close()
// Loop through rows, using Scan to assign column data to struct fields.
for rows.Next() {
var alb Album
if err := rows.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
}
albums = append(albums, alb)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %w", name, err)
}
return albums, nil
}
// albumByID queries for the album with the specified ID.
func albumByID(id int64) (Album, error) {
// An album to hold data from the returned row.
var alb Album
row := db.QueryRow("SELECT * FROM album WHERE id = ?", id)
if err := row.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
if err == sql.ErrNoRows {
return alb, fmt.Errorf("albumsById %d: no such album", id)
}
return alb, fmt.Errorf("albumsById %d: %w", id, err)
}
return alb, nil
}
// addAlbum adds the specified album to the database,
// returning the album ID of the new entry
func addAlbum(alb Album) (int64, error) {
result, err := db.Exec("INSERT INTO album (title, artist, price) VALUES (?, ?, ?)", alb.Title, alb.Artist, alb.Price)
if err != nil {
return 0, fmt.Errorf("addAlbum: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("addAlbum: %w", err)
}
return id, nil
}

31
driver/time.go Normal file
View File

@@ -0,0 +1,31 @@
package driver
import (
"database/sql/driver"
"time"
)
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return text
}
if len(text) > len(time.RFC3339Nano) {
return text
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return text
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
return date
}
return text
}

100
driver/time_test.go Normal file
View File

@@ -0,0 +1,100 @@
package driver
import (
"testing"
"time"
)
// This checks that any string can be recovered as the same string.
func Fuzz_maybeTime_1(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
f.Add(time.RFC3339)
f.Add(time.RFC3339Nano)
f.Add(time.Layout)
f.Add(time.DateTime)
f.Add(time.DateOnly)
f.Add(time.TimeOnly)
f.Add("2006-01-02T15:04:05Z")
f.Add("2006-01-02T15:04:05.000Z")
f.Add("2006-01-02T15:04:05.9999999999Z")
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeTime(str)
switch v := value.(type) {
case time.Time:
// Make sure times round-trip to the same string:
// https://pkg.go.dev/database/sql#Rows.Scan
if v.Format(time.RFC3339Nano) != str {
t.Fatalf("did not round-trip: %q", str)
}
case string:
if v != str {
t.Fatalf("did not round-trip: %q", str)
}
date, err := time.Parse(time.RFC3339Nano, str)
if err == nil && date.Format(time.RFC3339Nano) == str {
t.Fatalf("would round-trip: %q", str)
}
default:
t.Fatalf("invalid type %T: %q", v, str)
}
})
}
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_maybeTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
f.Add(0, 999_999_999)
f.Add(0, 1_000_000_000)
f.Add(7956915742, 222_222_222) // twosday
f.Add(639095955742, 222_222_222) // twosday, year 22222AD
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
value := maybeTime(date.Format(time.RFC3339Nano))
switch v := value.(type) {
case time.Time:
// Make sure times round-trip to the same time:
if !v.Equal(date) {
t.Fatalf("did not round-trip: %v", date)
}
// Make with the same zone offset:
_, off1 := v.Zone()
_, off2 := date.Zone()
if off1 != off2 {
t.Fatalf("did not round-trip: %v", date)
}
case string:
t.Fatalf("was not recovered: %v", date)
default:
t.Fatalf("invalid type %T: %v", v, date)
}
}
f.Fuzz(func(t *testing.T, sec, nsec int) {
// Reduce the search space.
if 1e12 < sec || sec < -1e12 {
// Dates before 29000BC and after 33000AD; I think we're safe.
return
}
if 0 < nsec || nsec > 1e10 {
// Out of range nsec: [time.Time.Unix] handles these.
return
}
unix := time.Unix(int64(sec), int64(nsec))
checkTime(t, unix)
checkTime(t, unix.UTC())
checkTime(t, unix.In(time.FixedZone("", -8*3600)))
checkTime(t, unix.In(time.FixedZone("", +8*3600)))
})
}

14
driver/util.go Normal file
View File

@@ -0,0 +1,14 @@
package driver
import "database/sql/driver"
func namedValues(args []driver.Value) []driver.NamedValue {
named := make([]driver.NamedValue, len(args))
for i, v := range args {
named[i] = driver.NamedValue{
Ordinal: i + 1,
Value: v,
}
}
return named
}

18
driver/util_test.go Normal file
View File

@@ -0,0 +1,18 @@
package driver
import (
"database/sql/driver"
"reflect"
"testing"
)
func Test_namedValues(t *testing.T) {
want := []driver.NamedValue{
{Ordinal: 1, Value: true},
{Ordinal: 2, Value: false},
}
got := namedValues([]driver.Value{true, false})
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

74
driver_test.go Normal file
View File

@@ -0,0 +1,74 @@
package sqlite3_test
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
func ExampleDriverConn() {
var err error
db, err = sql.Open("sqlite3", "demo.db")
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("demo.db")
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
log.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
if err != nil {
log.Fatal(err)
}
id, err := res.LastInsertId()
if err != nil {
log.Fatal(err)
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
defer conn.Savepoint()(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {
return err
}
defer blob.Close()
_, err = fmt.Fprint(blob, "Hello BLOB!")
return err
})
if err != nil {
log.Fatal(err)
}
var msg string
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
if err != nil {
log.Fatal(err)
}
fmt.Println(msg)
// Output:
// Hello BLOB!
}

View File

@@ -8,13 +8,13 @@ cd -P -- "$(dirname -- "$0")"
# build SQLite
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o sqlite3.wasm ../sqlite3/*.c \
-o sqlite3.wasm ../sqlite3/amalg.c \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-Wl,--export=malloc \
-Wl,--export=free \
-Wl,--export=malloc \
-Wl,--export=malloc_destructor \
-Wl,--export=sqlite3_errcode \
-Wl,--export=sqlite3_errstr \
@@ -28,15 +28,36 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_parameter_index \
-Wl,--export=sqlite3_bind_parameter_name \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_blob_open \
-Wl,--export=sqlite3_blob_close \
-Wl,--export=sqlite3_blob_bytes \
-Wl,--export=sqlite3_blob_read \
-Wl,--export=sqlite3_blob_write \
-Wl,--export=sqlite3_blob_reopen \
-Wl,--export=sqlite3_get_autocommit \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_unlock_notify \
-Wl,--export=sqlite3_backup_init \
-Wl,--export=sqlite3_backup_step \
-Wl,--export=sqlite3_backup_finish \
-Wl,--export=sqlite3_backup_remaining \
-Wl,--export=sqlite3_backup_pagecount \
-Wl,--export=sqlite3_interrupt_offset \

View File

@@ -1,3 +1,12 @@
// Package embed embeds SQLite into your application.
//
// Importing package embed initializes the [sqlite3.Binary] variable
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
package embed
import (

Binary file not shown.

140
error.go
View File

@@ -50,11 +50,144 @@ func (e *Error) Error() string {
return b.String()
}
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
//
// It makes it possible to do:
//
// if errors.Is(err, sqlite3.BUSY) {
// // ... handle BUSY
// }
func (e *Error) Is(err error) bool {
switch c := err.(type) {
case ErrorCode:
return c == e.Code()
case ExtendedErrorCode:
return c == e.ExtendedCode()
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e *Error) Timeout() bool {
return e.ExtendedCode() == BUSY_TIMEOUT
}
// SQL returns the SQL starting at the token that triggered a syntax error.
func (e *Error) SQL() string {
return e.sql
}
// Error implements the error interface.
func (e ErrorCode) Error() string {
switch e {
case _OK:
return "sqlite3: not an error"
case _ROW:
return "sqlite3: another row available"
case _DONE:
return "sqlite3: no more rows available"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
}
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY
}
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
switch x := ErrorCode(e); {
case e == ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case x < _ROW:
return x.Error()
case e == _ROW:
return "sqlite3: another row available"
case e == _DONE:
return "sqlite3: no more rows available"
}
return "sqlite3: unknown error"
}
// Is tests whether this error matches a given [ErrorCode].
func (e ExtendedErrorCode) Is(err error) bool {
c, ok := err.(ErrorCode)
return ok && c == ErrorCode(e)
}
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
type errorString string
func (e errorString) Error() string { return string(e) }
@@ -66,6 +199,13 @@ const (
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
emptyErr = errorString("sqlite3: empty statement")
tailErr = errorString("sqlite3: non-empty tail")
notImplErr = errorString("sqlite3: not implemented")
whenceErr = errorString("sqlite3: invalid whence")
offsetErr = errorString("sqlite3: invalid offset")
)
func assertErr() errorString {

View File

@@ -1,26 +1,157 @@
package sqlite3
import (
"context"
"errors"
"strings"
"testing"
)
func TestError(t *testing.T) {
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
}
if s := err.Error(); s != "sqlite3: 32896" {
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") {
t.Errorf("got %q", s)
}
}
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") {
func TestError(t *testing.T) {
t.Parallel()
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
}
if !errors.Is(&err, ErrorCode(0x80)) {
t.Errorf("want true")
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
}
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
t.Errorf("want true")
}
}
func TestError_Temporary(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code uint64
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), true},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
{
err := &Error{code: tt.code}
if got := err.Temporary(); got != tt.want {
t.Errorf("Error.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ExtendedErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
})
}
}
func TestError_Timeout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code uint64
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), false},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
{
err := &Error{code: tt.code}
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ExtendedErrorCode(tt.code)
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
}
})
}
}
func Test_ErrorCode_Error(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
got := ErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
}
}
func Test_ExtendedErrorCode_Error(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
got := ExtendedErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
}
}

View File

@@ -1,4 +1,4 @@
package main
package sqlite3_test
import (
"fmt"
@@ -8,8 +8,10 @@ import (
_ "github.com/ncruces/go-sqlite3/embed"
)
func main() {
db, err := sqlite3.Open(":memory:")
const memory = ":memory:"
func Example() {
db, err := sqlite3.Open(memory)
if err != nil {
log.Fatal(err)
}
@@ -19,15 +21,12 @@ func main() {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
stmt := db.MustPrepare(`SELECT id, name FROM users`)
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
@@ -45,4 +44,8 @@ func main() {
if err != nil {
log.Fatal(err)
}
// Output:
// 0 go
// 1 zig
// 2 whatever
}

2
go.mod
View File

@@ -4,7 +4,7 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-pre.8
github.com/tetratelabs/wazero v1.0.0-pre.9
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
)

4
go.sum
View File

@@ -1,7 +1,7 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-pre.8 h1:Ir82PWj79WCppH+9ny73eGY2qv+oCnE3VwMY92cBSyI=
github.com/tetratelabs/wazero v1.0.0-pre.8/go.mod h1:u8wrFmpdrykiFK0DFPiFm5a4+0RzsdmXYVtijBKqUVo=
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=

15
mem.go
View File

@@ -11,11 +11,14 @@ type memory struct {
mod api.Module
}
func (m memory) view(ptr, size uint32) []byte {
func (m memory) view(ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(nilErr)
}
buf, ok := m.mod.Memory().Read(ptr, size)
if size > math.MaxUint32 {
panic(rangeErr)
}
buf, ok := m.mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(rangeErr)
}
@@ -99,9 +102,13 @@ func (m memory) readString(ptr, maxlen uint32) string {
}
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint64(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
siz := uint32(len(s) + 1)
buf := m.view(ptr, siz)
buf := m.view(ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -19,6 +19,13 @@ func Test_memory_view_range(t *testing.T) {
t.Error("want panic")
}
func Test_memory_view_overflow(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(1, math.MaxInt64)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)

9
sqlite3/amalg.c Normal file
View File

@@ -0,0 +1,9 @@
#include <stddef.h>
#include "main.c"
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);

View File

@@ -4,7 +4,7 @@ set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "sqlite3.c" ]; then
url="https://www.sqlite.org/2022/sqlite-amalgamation-3400100.zip"
url="https://sqlite.org/2023/sqlite-amalgamation-3410000.zip"
curl "$url" > sqlite.zip
unzip -d . sqlite.zip
mv sqlite-amalgamation-*/sqlite3* .

8
sqlite3/format.sh Executable file
View File

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
cd -P -- "$(dirname -- "$0")"
clang-format -i \
main.c \
os.c \
qsort.c \
amalg.c

View File

@@ -1,5 +1,4 @@
#include <stdlib.h>
#include <time.h>
#include <stdbool.h>
#include "sqlite3.h"
@@ -8,93 +7,8 @@ int main() {
if (rc != SQLITE_OK) return 1;
}
int go_localtime(sqlite3_int64, struct tm *);
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
int go_sleep(sqlite3_vfs *, int microseconds);
int go_current_time(sqlite3_vfs *, double *);
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct go_file {
sqlite3_file base;
int id;
int eLock;
};
int go_close(sqlite3_file *);
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int go_truncate(sqlite3_file *, sqlite3_int64 size);
int go_sync(sqlite3_file *, int flags);
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int go_lock(sqlite3_file *pFile, int eLock);
int go_unlock(sqlite3_file *pFile, int eLock);
int go_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime((sqlite3_int64)*pTime, pTm);
}
static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods go_io = {
.iVersion = 1,
.xClose = go_close,
.xRead = go_read,
.xWrite = go_write,
.xTruncate = go_truncate,
.xSync = go_sync,
.xFileSize = go_file_size,
.xLock = go_lock,
.xUnlock = go_unlock,
.xCheckReservedLock = go_check_reserved_lock,
.xFileControl = no_file_control,
.xSectorSize = no_sector_size,
.xDeviceCharacteristics = no_device_characteristics,
};
int rc = go_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &go_io : NULL;
return rc;
}
sqlite3_vfs *os_vfs();
int sqlite3_os_init() {
static sqlite3_vfs go_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = "go",
.xOpen = go_open_c,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&go_vfs, /*default=*/1);
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;

92
sqlite3/os.c Normal file
View File

@@ -0,0 +1,92 @@
#include <time.h>
#include "sqlite3.h"
int os_localtime(sqlite3_int64, struct tm *);
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
int os_sleep(sqlite3_vfs *, int microseconds);
int os_current_time(sqlite3_vfs *, double *);
int os_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int os_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int os_delete(sqlite3_vfs *, const char *zName, int syncDir);
int os_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct os_file {
sqlite3_file base;
int id;
int lock;
};
int os_close(sqlite3_file *);
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int os_truncate(sqlite3_file *, sqlite3_int64 size);
int os_sync(sqlite3_file *, int flags);
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int os_file_control(sqlite3_file *pFile, int op, void *pArg);
int os_lock(sqlite3_file *pFile, int eLock);
int os_unlock(sqlite3_file *pFile, int eLock);
int os_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime((sqlite3_int64)*pTime, pTm);
}
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = os_close,
.xRead = os_read,
.xWrite = os_write,
.xTruncate = os_truncate,
.xSync = os_sync,
.xFileSize = os_file_size,
.xLock = os_lock,
.xUnlock = os_unlock,
.xCheckReservedLock = os_check_reserved_lock,
.xFileControl = no_file_control,
.xDeviceCharacteristics = no_device_characteristics,
};
int rc = os_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &os_io : NULL;
return rc;
}
sqlite3_vfs *os_vfs() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct os_file),
.mxPathname = 512,
.zName = "os",
.xOpen = os_open_w,
.xDelete = os_delete,
.xAccess = os_access,
.xFullPathname = os_full_pathname,
.xRandomness = os_randomness,
.xSleep = os_sleep,
.xCurrentTime = os_current_time,
.xCurrentTimeInt64 = os_current_time_64,
};
return &os_vfs;
}

14
sqlite3/qsort.c Normal file
View File

@@ -0,0 +1,14 @@
#include <stddef.h>
void qsort_r(void *, size_t, size_t,
int (*)(const void *, const void *, void *), void *);
typedef int (*cmpfun)(const void *, const void *);
static int wrapper_cmp(const void *v1, const void *v2, void *cmp) {
return ((cmpfun)cmp)(v1, v2);
}
void qsort(void *base, size_t nel, size_t width, cmpfun cmp) {
qsort_r(base, nel, width, wrapper_cmp, cmp);
}

View File

@@ -5,6 +5,9 @@
#define SQLITE_OS_OTHER 1
#define SQLITE_BYTEORDER 1234
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define HAVE_ISNAN 1
#define HAVE_USLEEP 1
#define HAVE_LOCALTIME_S 1
@@ -25,12 +28,33 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Need this to access WAL databases without the use of shared memory.
// Because WASM does not support shared memory,
// SQLite disables it for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#define SQLITE_DEFAULT_LOCKING_MODE 1
// Go uses UTF-8 everywhere.
#define SQLITE_OMIT_UTF16
// Remove some testing code.
#define SQLITE_UNTESTABLE
// Recommended Extensions
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
#define SQLITE_ENABLE_JSON1 1
#define SQLITE_ENABLE_FTS3 1
#define SQLITE_ENABLE_FTS3_PARENTHESIS 1
#define SQLITE_ENABLE_FTS4 1
#define SQLITE_ENABLE_FTS5 1
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
// Snapshot
// #define SQLITE_ENABLE_SNAPSHOT 1
// Session Extension
// #define SQLITE_ENABLE_SESSION 1
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
// Resumable Bulk Update Extension
// #define SQLITE_ENABLE_RBU 1
// Implemented in Go.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

229
stmt.go
View File

@@ -2,6 +2,7 @@ package sqlite3
import (
"math"
"time"
)
// Stmt is a prepared statement object.
@@ -15,16 +16,15 @@ type Stmt struct {
// Close destroys the prepared statement object.
//
// It is safe to close a nil, zero or closed Stmt.
//
// https://www.sqlite.org/c3ref/finalize.html
func (s *Stmt) Close() error {
if s == nil {
if s == nil || s.handle == 0 {
return nil
}
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
if err != nil {
return err
}
r := s.c.call(s.c.api.finalize, uint64(s.handle))
s.handle = 0
return s.c.error(r[0])
@@ -34,10 +34,7 @@ func (s *Stmt) Close() error {
//
// https://www.sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
if err != nil {
return err
}
r := s.c.call(s.c.api.reset, uint64(s.handle))
s.err = nil
return s.c.error(r[0])
}
@@ -46,10 +43,7 @@ func (s *Stmt) Reset() error {
//
// https://www.sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
if err != nil {
return err
}
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r[0])
}
@@ -63,11 +57,8 @@ func (s *Stmt) ClearBindings() error {
//
// https://www.sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
if err != nil {
s.err = err
return false
}
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r[0] == _ROW {
return true
}
@@ -95,6 +86,42 @@ func (s *Stmt) Exec() error {
return s.Reset()
}
// BindCount returns the number of SQL parameters in the prepared statement.
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r := s.c.call(s.c.api.bindCount,
uint64(s.handle))
return int(r[0])
}
// BindIndex returns the index of a parameter in the prepared statement
// given its name.
//
// https://www.sqlite.org/c3ref/bind_parameter_index.html
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.reset()
namePtr := s.c.arena.string(name)
r := s.c.call(s.c.api.bindIndex,
uint64(s.handle), uint64(namePtr))
return int(r[0])
}
// BindName returns the name of a parameter in the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r := s.c.call(s.c.api.bindName,
uint64(s.handle), uint64(param))
ptr := uint32(r[0])
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
}
// BindBool binds a bool to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// SQLite does not have a separate boolean storage class.
@@ -121,11 +148,8 @@ func (s *Stmt) BindInt(param int, value int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
r, err := s.c.api.bindInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.bindInteger,
uint64(s.handle), uint64(param), uint64(value))
if err != nil {
return err
}
return s.c.error(r[0])
}
@@ -134,11 +158,8 @@ func (s *Stmt) BindInt64(param int, value int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
r, err := s.c.api.bindFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.bindFloat,
uint64(s.handle), uint64(param), math.Float64bits(value))
if err != nil {
return err
}
return s.c.error(r[0])
}
@@ -148,13 +169,10 @@ func (s *Stmt) BindFloat(param int, value float64) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindText(param int, value string) error {
ptr := s.c.newString(value)
r, err := s.c.api.bindText.Call(s.c.ctx,
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
if err != nil {
return err
}
return s.c.error(r[0])
}
@@ -165,13 +183,20 @@ func (s *Stmt) BindText(param int, value string) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
ptr := s.c.newBytes(value)
r, err := s.c.api.bindBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
if err != nil {
return err
}
return s.c.error(r[0])
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r := s.c.call(s.c.api.bindZeroBlob,
uint64(s.handle), uint64(param), uint64(n))
return s.c.error(r[0])
}
@@ -180,24 +205,60 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
r, err := s.c.api.bindNull.Call(s.c.ctx,
r := s.c.call(s.c.api.bindNull,
uint64(s.handle), uint64(param))
if err != nil {
return err
}
return s.c.error(r[0])
}
// BindTime binds a [time.Time] to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
case int64:
s.BindInt64(param, v)
case float64:
s.BindFloat(param, v)
default:
panic(assertErr())
}
return nil
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r := s.c.call(s.c.api.columnCount,
uint64(s.handle))
return int(r[0])
}
// ColumnName returns the name of the result column.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r := s.c.call(s.c.api.columnName,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
panic(oomErr)
}
return s.c.mem.readString(ptr, _MAX_STRING)
}
// ColumnType returns the initial [Datatype] of the result column.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
r, err := s.c.api.columnType.Call(s.c.ctx,
r := s.c.call(s.c.api.columnType,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return Datatype(r[0])
}
@@ -228,11 +289,8 @@ func (s *Stmt) ColumnInt(col int) int {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
r, err := s.c.api.columnInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.columnInteger,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return int64(r[0])
}
@@ -241,42 +299,55 @@ func (s *Stmt) ColumnInt64(col int) int64 {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
r, err := s.c.api.columnFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.columnFloat,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return math.Float64frombits(r[0])
}
// ColumnTime returns the value of the result column as a [time.Time].
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
var v any
switch s.ColumnType(col) {
case INTEGER:
v = s.ColumnInt64(col)
case FLOAT:
v = s.ColumnFloat(col)
case TEXT, BLOB:
v = s.ColumnText(col)
case NULL:
return time.Time{}
default:
panic(assertErr())
}
t, err := format.Decode(v)
if err != nil {
s.err = err
}
return t
}
// ColumnText returns the value of the result column as a string.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r, err := s.c.api.columnText.Call(s.c.ctx,
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.handle))
s.err = s.c.error(r[0])
return ""
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
mem := s.c.mem.view(ptr, r[0])
return string(mem)
}
@@ -286,28 +357,34 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
r, err := s.c.api.columnBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.handle))
s.err = s.c.error(r[0])
return buf[0:0]
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
mem := s.c.mem.view(ptr, r[0])
return append(buf[0:0], mem...)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

View File

@@ -1,361 +1,60 @@
package sqlite3
import (
"math"
"testing"
)
func TestStmt(t *testing.T) {
func Test_emptyStatement(t *testing.T) {
t.Parallel()
tests := []struct {
name string
stmt string
want bool
}{
{"empty", "", true},
{"space", " ", true},
{"separator", ";\n ", true},
{"begin", "BEGIN", false},
{"select", "SELECT 1;", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := emptyStatement(tt.stmt); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func Fuzz_emptyStatement(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add(";\n ")
f.Add("; ;\v")
f.Add("BEGIN")
f.Add("SELECT 1;")
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
f.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO test(col) VALUES(?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = stmt.BindBool(1, false)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.ClearBindings()
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBool(1, true)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindInt(1, 2)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindFloat(1, math.Pi)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindNull(1)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "text")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, []byte("blob"))
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, nil)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
// The table should have: 0, NULL, 1, 2, π, NULL, "", "text", `blob`, NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
t.Errorf("got %v, want INTEGER", got)
f.Fuzz(func(t *testing.T, sql string) {
// If empty, SQLite parses it as empty.
if emptyStatement(sql) {
stmt, tail, err := db.Prepare(sql)
if err != nil {
t.Errorf("%q, %v", sql, err)
}
if stmt != nil {
t.Errorf("%q, %v", sql, stmt)
}
if tail != "" {
t.Errorf("%q", sql)
}
stmt.Close()
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "0" {
t.Errorf("got %q, want zero", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 1 {
t.Errorf("got %v, want one", got)
}
if got := stmt.ColumnFloat(0); got != 1 {
t.Errorf("got %v, want one", got)
}
if got := stmt.ColumnText(0); got != "1" {
t.Errorf("got %q, want one", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 2 {
t.Errorf("got %v, want two", got)
}
if got := stmt.ColumnFloat(0); got != 2 {
t.Errorf("got %v, want two", got)
}
if got := stmt.ColumnText(0); got != "2" {
t.Errorf("got %q, want two", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %v, want three", got)
}
if got := stmt.ColumnFloat(0); got != math.Pi {
t.Errorf("got %v, want π", got)
}
if got := stmt.ColumnText(0); got != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "text" {
t.Errorf(`got %q, want "text"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestStmt_Close(t *testing.T) {
var stmt *Stmt
stmt.Close()
})
}

297
tests/blob_test.go Normal file
View File

@@ -0,0 +1,297 @@
package tests
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBlob(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
size := blob.Size()
if size != 1024 {
t.Errorf("got %d, want 1024", size)
}
var data [1280]byte
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:size/2]))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:]))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
_, err = io.Copy(blob, bytes.NewReader(data[size/2:size]))
if err != nil {
t.Fatal(err)
}
_, err = blob.Seek(size/4, io.SeekStart)
if err != nil {
t.Fatal(err)
}
if got, err := io.ReadAll(blob); err != nil {
t.Fatal(err)
} else if !bytes.Equal(got, data[size/4:size]) {
t.Errorf("got %q, want %q", got, data[size/4:size])
}
if n, err := blob.Read(make([]byte, 1)); n != 0 || err != io.EOF {
t.Errorf("got (%d, %v), want (0, EOF)", n, err)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
_, err = db.OpenBlob("", "test", "col", db.LastInsertRowID(), false)
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
}
func TestBlob_Write_readonly(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
_, err = blob.Write([]byte("data"))
if !errors.Is(err, sqlite3.READONLY) {
t.Fatal("want error")
}
}
func TestBlob_Read_expired(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
err = db.Exec(`DELETE FROM test`)
if err != nil {
t.Fatal(err)
}
_, err = io.ReadAll(blob)
if !errors.Is(err, sqlite3.ABORT) {
t.Fatal("want error", err)
}
}
func TestBlob_Seek(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
_, err = blob.Seek(0, 10)
if err == nil {
t.Fatal("want error")
}
_, err = blob.Seek(-1, io.SeekCurrent)
if err == nil {
t.Fatal("want error")
}
n, err := blob.Seek(1, io.SeekEnd)
if err != nil {
t.Fatal(err)
}
if n != blob.Size()+1 {
t.Errorf("got %d, want %d", n, blob.Size())
}
_, err = blob.Write([]byte("data"))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
}
func TestBlob_Reopen(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
var rowids []int64
for i := 0; i < 100; i++ {
err = db.Exec(`INSERT INTO test VALUES (zeroblob(10))`)
if err != nil {
t.Fatal(err)
}
rowids = append(rowids, db.LastInsertRowID())
}
var blob *sqlite3.Blob
for i, rowid := range rowids {
if i > 0 {
err = blob.Reopen(rowid)
} else {
blob, err = db.OpenBlob("main", "test", "col", rowid, true)
}
if err != nil {
t.Fatal(err)
}
_, err = fmt.Fprintf(blob, "blob %d\n", i)
if err != nil {
t.Fatal(err)
}
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
for i, rowid := range rowids {
if i > 0 {
err = blob.Reopen(rowid)
} else {
blob, err = db.OpenBlob("main", "test", "col", rowid, false)
}
if err != nil {
t.Fatal(err)
}
var got int
_, err = fmt.Fscanf(blob, "blob %d\n", &got)
if err != nil {
t.Fatal(err)
}
if got != i {
t.Errorf("got %d, want %d", got, i)
}
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
}

182
tests/bradfitz/sql_test.go Normal file
View File

@@ -0,0 +1,182 @@
package bradfitz
// Adapted from: https://github.com/bradfitz/go-sql-test
import (
"database/sql"
"fmt"
"math/rand"
"path/filepath"
"sync"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
type Tester interface {
RunTest(*testing.T, func(params))
}
var (
sqlite Tester = sqliteDB{}
)
const TablePrefix = "gosqltest_"
type sqliteDB struct{}
type params struct {
dbType Tester
*testing.T
*sql.DB
}
func (t params) mustExec(sql string, args ...interface{}) sql.Result {
res, err := t.DB.Exec(sql, args...)
if err != nil {
t.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "foo.db")+
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}
fn(params{sqlite, t, db})
if err := db.Close(); err != nil {
t.Fatalf("foo.db close fail: %v", err)
}
}
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
func testBlobs(t params) {
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar blob)")
t.mustExec("insert into "+TablePrefix+"foo (id, bar) values(?,?)", 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) }
func testManyQueryRow(t params) {
if testing.Short() {
t.Skip("skipping in short mode")
}
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
func TestTxQuery_SQLite(t *testing.T) { sqlite.RunTest(t, testTxQuery) }
func testTxQuery(t params) {
tx, err := t.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
if err != nil {
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
}
_, err = tx.Exec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query("select name from "+TablePrefix+"foo where id = ?", 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) }
func testPreparedStmt(t params) {
if testing.Short() {
t.Skip("skipping in short mode")
}
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := t.Prepare("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"

View File

@@ -1,4 +1,4 @@
package compile_empty
package compile
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/ncruces/go-sqlite3"
)
func TestCompile_empty(t *testing.T) {
func TestCompile_missing(t *testing.T) {
sqlite3.Path = "sqlite3.wasm"
_, err := sqlite3.Open(":memory:")
if err == nil {

View File

@@ -0,0 +1,14 @@
package compile
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestCompile_nil(t *testing.T) {
_, err := sqlite3.Open(":memory:")
if err == nil {
t.Error("want error")
}
}

260
tests/conn_test.go Normal file
View File

@@ -0,0 +1,260 @@
package tests
import (
"context"
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
if err == nil {
t.Fatal("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func TestConn_Close(t *testing.T) {
var conn *sqlite3.Conn
conn.Close()
}
func TestConn_Close_BUSY(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = db.Close()
if err == nil {
t.Fatal("want error")
}
if !errors.Is(err, sqlite3.BUSY) {
t.Errorf("got %v, want sqlite3.BUSY", err)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message:", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
// Interrupting works.
err = stmt.Exec()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
db.SetInterrupt(ctx)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt != nil {
t.Error("want nil")
}
}
func TestConn_Prepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if !strings.Contains(tail, "-- HERE") {
t.Errorf("got %q", tail)
}
}
func TestConn_Prepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var serr *sqlite3.Error
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message:", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.ERROR", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL:", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message:", got)
}
}
func TestConn_MustPrepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(``)
t.Error("want panic")
}
func TestConn_MustPrepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT 1; -- HERE`)
t.Error("want panic")
}
func TestConn_MustPrepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.Pragma("encoding=''")
t.Error("want panic")
}

View File

@@ -1,7 +1,6 @@
package tests
import (
"os"
"path/filepath"
"testing"
@@ -14,16 +13,12 @@ func TestDB_memory(t *testing.T) {
}
func TestDB_file(t *testing.T) {
dir, err := os.MkdirTemp("", "sqlite3-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
testDB(t, filepath.Join(dir, "test.db"))
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func testDB(t *testing.T, name string) {
t.Parallel()
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)
@@ -35,32 +30,41 @@ func testDB(t *testing.T, name string) {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
changes := db.Changes()
if changes != 3 {
t.Errorf("got %d want 3", changes)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; stmt.Step(); row++ {
if ids[row] != stmt.ColumnInt(0) {
t.Errorf("got %d, want %d", stmt.ColumnInt(0), ids[row])
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if names[row] != stmt.ColumnText(1) {
t.Errorf("got %q, want %q", stmt.ColumnText(1), names[row])
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
if row != 3 {
t.Errorf("got %d rows, want %d", row, len(ids))
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()

View File

@@ -1,26 +0,0 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDir(t *testing.T) {
_, err := sqlite3.Open(".")
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
}
}

103
tests/driver_test.go Normal file
View File

@@ -0,0 +1,103 @@
package tests
import (
"context"
"database/sql"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDriver(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
res, err := conn.ExecContext(ctx,
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
changes, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if changes != 3 {
t.Errorf("got %d want 3", changes)
}
stmt, err := conn.PrepareContext(context.Background(),
`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
t.Fatal(err)
}
defer rows.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; rows.Next(); row++ {
var id int
var name string
err := rows.Scan(&id, &name)
if err != nil {
t.Fatal(err)
}
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
err = rows.Close()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,188 @@
package tests
import (
"io"
"os"
"os/exec"
"path/filepath"
"testing"
"golang.org/x/sync/errgroup"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestParallel(t *testing.T) {
name := filepath.Join(t.TempDir(), "test.db")
testParallel(t, name, 1000)
testIntegrity(t, name)
}
func TestMultiProcess(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
name := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbname", name)
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
out, err := cmd.StdoutPipe()
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
var buf [3]byte
// Wait for child to start.
if _, err := io.ReadFull(out, buf[:]); err != nil || string(buf[:]) != "===" {
t.Fatal(err)
}
testParallel(t, name, 1000)
if err := cmd.Wait(); err != nil {
t.Error(err)
}
testIntegrity(t, name)
}
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
t.SkipNow()
}
testParallel(t, name, 1000)
}
func testParallel(t *testing.T, name string, n int) {
writer := func() error {
db, err := sqlite3.Open(name)
if err != nil {
return err
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA synchronous=off;
PRAGMA locking_mode=normal;
PRAGMA journal_mode=truncate;
`)
if err != nil {
return err
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
return err
}
return db.Close()
}
reader := func() error {
db, err := sqlite3.Open(name)
if err != nil {
return err
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA locking_mode=normal;
`)
if err != nil {
return err
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
return err
}
defer stmt.Close()
row := 0
for stmt.Step() {
row++
}
if err := stmt.Err(); err != nil {
return err
}
if row%3 != 0 {
t.Errorf("got %d rows, want multiple of 3", row)
}
err = stmt.Close()
if err != nil {
return err
}
return db.Close()
}
err := writer()
if err != nil {
t.Fatal(err)
}
var group errgroup.Group
group.SetLimit(4)
for i := 0; i < n; i++ {
if i&7 != 7 {
group.Go(reader)
} else {
group.Go(writer)
}
}
err = group.Wait()
if err != nil {
t.Error(err)
}
}
func testIntegrity(t *testing.T, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)
}
defer db.Close()
test := `PRAGMA integrity_check`
if testing.Short() {
test = `PRAGMA quick_check`
}
stmt, _, err := db.Prepare(test)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
if row := stmt.ColumnText(0); row != "ok" {
t.Error(row)
}
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -1,111 +0,0 @@
package tests
import (
"os"
"path/filepath"
"runtime"
"testing"
"golang.org/x/sync/errgroup"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestParallel(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip()
}
dir, err := os.MkdirTemp("", "sqlite3-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
writer := func() error {
db, err := sqlite3.Open(filepath.Join(dir, "test.db"))
if err != nil {
return err
}
defer db.Close()
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 1000;
`)
if err != nil {
return err
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
return err
}
return db.Close()
}
reader := func() error {
db, err := sqlite3.Open(filepath.Join(dir, "test.db"))
if err != nil {
return err
}
defer db.Close()
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 1000;
`)
if err != nil {
return err
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
return err
}
row := 0
for stmt.Step() {
row++
}
if err := stmt.Err(); err != nil {
return err
}
if row%3 != 0 {
t.Errorf("got %d rows, want multiple of 3", row)
}
err = stmt.Close()
if err != nil {
return err
}
return db.Close()
}
err = writer()
if err != nil {
t.Fatal(err)
}
var group errgroup.Group
group.SetLimit(4)
for i := 0; i < 32; i++ {
if i&7 != 7 {
group.Go(reader)
} else {
group.Go(writer)
}
}
err = group.Wait()
if err != nil {
t.Fatal(err)
}
}

463
tests/stmt_test.go Normal file
View File

@@ -0,0 +1,463 @@
package tests
import (
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestStmt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO test VALUES (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if got := stmt.BindCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
if err := stmt.BindBool(1, false); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBool(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindInt(1, 2); err != nil {
t.Fatal(err)
}
if err = stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindFloat(1, math.Pi); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindNull(1); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindText(1, ""); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindText(1, "text"); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, nil); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindZeroBlob(1, 4); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "0" {
t.Errorf("got %q, want zero", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 1 {
t.Errorf("got %v, want one", got)
}
if got := stmt.ColumnFloat(0); got != 1 {
t.Errorf("got %v, want one", got)
}
if got := stmt.ColumnText(0); got != "1" {
t.Errorf("got %q, want one", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 2 {
t.Errorf("got %v, want two", got)
}
if got := stmt.ColumnFloat(0); got != 2 {
t.Errorf("got %v, want two", got)
}
if got := stmt.ColumnText(0); got != "2" {
t.Errorf("got %q, want two", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %v, want three", got)
}
if got := stmt.ColumnFloat(0); got != math.Pi {
t.Errorf("got %v, want π", got)
}
if got := stmt.ColumnText(0); got != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "text" {
t.Errorf(`got %q, want "text"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if err := stmt.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestStmt_Close(t *testing.T) {
var stmt *sqlite3.Stmt
stmt.Close()
}
func TestStmt_BindName(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
want := []string{"", "", "", "", "?5", ":AAA", "@AAA", "$AAA"}
stmt, _, err := db.Prepare(`SELECT ?, ?5, :AAA, @AAA, $AAA`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if got := stmt.BindCount(); got != len(want) {
t.Errorf("got %d, want %d", got, len(want))
}
for i, name := range want {
id := i + 1
if got := stmt.BindName(id); got != name {
t.Errorf("got %q, want %q", got, name)
}
if name == "" {
id = 0
}
if got := stmt.BindIndex(name); got != id {
t.Errorf("got %d, want %d", got, id)
}
}
}
func TestStmt_ColumnTime(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT ?, ?, ?, datetime(), unixepoch(), julianday(), NULL, 'abc'`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
err = stmt.BindTime(1, reference, sqlite3.TimeFormat4)
if err != nil {
t.Fatal(err)
}
err = stmt.BindTime(2, reference, sqlite3.TimeFormatUnixMilli)
if err != nil {
t.Fatal(err)
}
err = stmt.BindTime(3, reference, sqlite3.TimeFormatJulianDay)
if err != nil {
t.Fatal(err)
}
if now := time.Now(); stmt.Step() {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); got.Sub(reference).Abs() > time.Millisecond {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second/10 {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(6, sqlite3.TimeFormatAuto); got != (time.Time{}) {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnTime(7, sqlite3.TimeFormatAuto); got != (time.Time{}) {
t.Errorf("got %v, want zero", got)
}
if stmt.Err() == nil {
t.Errorf("want error")
}
}
}

504
tests/tx_test.go Normal file
View File

@@ -0,0 +1,504 @@
package tests
import (
"context"
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestConn_Transaction_exec(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
errFailed := errors.New("failed")
count := func() int {
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
return stmt.ColumnInt(0)
}
t.Fatal(stmt.Err())
return 0
}
insert := func(succeed bool) (err error) {
tx := db.Begin()
defer tx.End(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errFailed
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
err = insert(false)
if err != errFailed {
t.Errorf("got %v, want errFailed", err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
}
func TestConn_Transaction_panic(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES ('one');`)
if err != nil {
t.Fatal(err)
}
panics := func() (err error) {
tx := db.Begin()
defer tx.End(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
return err
}
panic("omg!")
}
defer func() {
p := recover()
if p != "omg!" {
t.Errorf("got %v, want panic", p)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
return
}
t.Fatal(stmt.Err())
}()
err = panics()
if err != nil {
t.Error(err)
}
}
func TestConn_Transaction_interrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
tx, err := db.BeginImmediate()
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
tx.End(&err)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
tx, err = db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
cancel()
_, err = db.BeginImmediate()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = db.Exec(`INSERT INTO test VALUES (3)`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
var nilErr error
tx.End(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
tx := db.Begin()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`COMMIT`)
if err != nil {
t.Fatal(err)
}
tx.End(&err)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_exec(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
errFailed := errors.New("failed")
count := func() int {
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
return stmt.ColumnInt(0)
}
t.Fatal(stmt.Err())
return 0
}
insert := func(succeed bool) (err error) {
defer db.Savepoint()(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errFailed
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
err = insert(false)
if err != errFailed {
t.Errorf("got %v, want errFailed", err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
}
func TestConn_Savepoint_panic(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES ('one');`)
if err != nil {
t.Fatal(err)
}
panics := func() (err error) {
defer db.Savepoint()(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
return err
}
panic("omg!")
}
defer func() {
p := recover()
if p != "omg!" {
t.Errorf("got %v, want panic", p)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
return
}
t.Fatal(stmt.Err())
}()
err = panics()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_interrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
release := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
release(&err)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
release1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
release2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint()(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = db.Exec(`INSERT INTO test VALUES (4)`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = context.Canceled
release2(&err)
if err != context.Canceled {
t.Fatal(err)
}
var nilErr error
release1(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_rollback(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
release := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`COMMIT`)
if err != nil {
t.Fatal(err)
}
release(&err)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}

332
time.go Normal file
View File

@@ -0,0 +1,332 @@
package sqlite3
import (
"math"
"strconv"
"strings"
"time"
"github.com/ncruces/julianday"
)
// TimeFormat specifies how to encode/decode time values.
//
// See the documentation for the [TimeFormatDefault] constant
// for formats recognized by SQLite.
//
// https://www.sqlite.org/lang_datefunc.html
type TimeFormat string
// TimeFormats recognized by SQLite to encode/decode time values.
//
// https://www.sqlite.org/lang_datefunc.html
const (
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano
// Text formats
TimeFormat1 TimeFormat = "2006-01-02"
TimeFormat2 TimeFormat = "2006-01-02 15:04"
TimeFormat3 TimeFormat = "2006-01-02 15:04:05"
TimeFormat4 TimeFormat = "2006-01-02 15:04:05.000"
TimeFormat5 TimeFormat = "2006-01-02T15:04"
TimeFormat6 TimeFormat = "2006-01-02T15:04:05"
TimeFormat7 TimeFormat = "2006-01-02T15:04:05.000"
TimeFormat8 TimeFormat = "15:04"
TimeFormat9 TimeFormat = "15:04:05"
TimeFormat10 TimeFormat = "15:04:05.000"
TimeFormat2TZ = TimeFormat2 + "Z07:00"
TimeFormat3TZ = TimeFormat3 + "Z07:00"
TimeFormat4TZ = TimeFormat4 + "Z07:00"
TimeFormat5TZ = TimeFormat5 + "Z07:00"
TimeFormat6TZ = TimeFormat6 + "Z07:00"
TimeFormat7TZ = TimeFormat7 + "Z07:00"
TimeFormat8TZ = TimeFormat8 + "Z07:00"
TimeFormat9TZ = TimeFormat9 + "Z07:00"
TimeFormat10TZ = TimeFormat10 + "Z07:00"
// Numeric formats
TimeFormatJulianDay TimeFormat = "julianday"
TimeFormatUnix TimeFormat = "unixepoch"
TimeFormatUnixFrac TimeFormat = "unixepoch_frac"
TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format
TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format
TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format
// Auto
TimeFormatAuto TimeFormat = "auto"
)
// Encode encodes a time value using this format.
//
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving any timezone offset.
//
// This is the format used by the database/sql driver:
// [database/sql.Row.Scan] is able to decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
// Use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
//
// Returns a string for the text formats,
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
case TimeFormatJulianDay:
return julianday.Float(t)
case TimeFormatUnix:
return t.Unix()
case TimeFormatUnixFrac:
return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
case TimeFormatUnixMilli:
return t.UnixMilli()
case TimeFormatUnixMicro:
return t.UnixMicro()
case TimeFormatUnixNano:
return t.UnixNano()
// Special formats
case TimeFormatDefault, TimeFormatAuto:
f = time.RFC3339Nano
// SQLite assumes UTC if unspecified.
case
TimeFormat1, TimeFormat2,
TimeFormat3, TimeFormat4,
TimeFormat5, TimeFormat6,
TimeFormat7, TimeFormat8,
TimeFormat9, TimeFormat10:
t = t.UTC()
}
return t.Format(string(f))
}
// Decode decodes a time value using this format.
//
// The time value can be a string, an int64, or a float64.
//
// Formats [TimeFormat8] through [TimeFormat10]
// (and [TimeFormat8TZ] through [TimeFormat10TZ])
// assume a date of 2000-01-01.
//
// The timezone indicator and fractional seconds are always optional
// for formats [TimeFormat2] through [TimeFormat10]
// (and [TimeFormat2TZ] through [TimeFormat10TZ]).
//
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
// are safe to use for current events, from 1980 through at least 2260.
// Unix timestamps before 1980 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) {
switch f {
// Numeric formats
case TimeFormatJulianDay:
switch v := v.(type) {
case string:
return julianday.Parse(v)
case float64:
return julianday.FloatTime(v), nil
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
if s, ok := v.(string); ok {
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return time.Time{}, err
}
v = f
}
switch v := v.(type) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)), nil
case int64:
return time.Unix(v, 0), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnixMilli:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(math.Floor(v))), nil
case int64:
return time.UnixMilli(int64(v)), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnixMicro:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(math.Floor(v))), nil
case int64:
return time.UnixMicro(int64(v)), nil
default:
return time.Time{}, timeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, timeErr
}
v = i
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(math.Floor(v))), nil
case int64:
return time.Unix(0, int64(v)), nil
default:
return time.Time{}, timeErr
}
// Special formats
case TimeFormatAuto:
switch s := v.(type) {
case string:
i, err := strconv.ParseInt(s, 10, 64)
if err == nil {
v = i
break
}
f, err := strconv.ParseFloat(s, 64)
if err == nil {
v = f
break
}
dates := []TimeFormat{
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
TimeFormat1,
}
for _, f := range dates {
t, err := time.Parse(string(f), s)
if err == nil {
return t, nil
}
}
times := []TimeFormat{
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
}
for _, f := range times {
t, err := time.Parse(string(f), s)
if err == nil {
return t.AddDate(2000, 0, 0), nil
}
}
}
switch v := v.(type) {
case float64:
if 0 <= v && v < 5373484.5 {
return TimeFormatJulianDay.Decode(v)
}
if v < 253402300800 {
return TimeFormatUnixFrac.Decode(v)
}
if v < 253402300800_000 {
return TimeFormatUnixMilli.Decode(v)
}
if v < 253402300800_000000 {
return TimeFormatUnixMicro.Decode(v)
}
return TimeFormatUnixNano.Decode(v)
case int64:
if 0 <= v && v < 5373485 {
return TimeFormatJulianDay.Decode(v)
}
if v < 253402300800 {
return TimeFormatUnixFrac.Decode(v)
}
if v < 253402300800_000 {
return TimeFormatUnixMilli.Decode(v)
}
if v < 253402300800_000000 {
return TimeFormatUnixMicro.Decode(v)
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, timeErr
}
case
TimeFormat2, TimeFormat2TZ,
TimeFormat3, TimeFormat3TZ,
TimeFormat4, TimeFormat4TZ,
TimeFormat5, TimeFormat5TZ,
TimeFormat6, TimeFormat6TZ,
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
return f.parseRelaxed(s)
case
TimeFormat8, TimeFormat8TZ,
TimeFormat9, TimeFormat9TZ,
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
default:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
if f == "" {
f = time.RFC3339Nano
}
return time.Parse(string(f), s)
}
}
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
fs := string(f)
fs = strings.TrimSuffix(fs, "Z07:00")
fs = strings.TrimSuffix(fs, ".000")
t, err := time.Parse(fs+"Z07:00", s)
if err != nil {
return time.Parse(fs, s)
}
return t, nil
}

118
time_test.go Normal file
View File

@@ -0,0 +1,118 @@
package sqlite3
import (
"reflect"
"testing"
"time"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
time time.Time
want any
}{
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{TimeFormatJulianDay, reference, 2456572.849526851851852},
{TimeFormatUnix, reference, int64(1381134199)},
{TimeFormatUnixFrac, reference, 1381134199.120},
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{TimeFormatJulianDay, false, time.Time{}, 0, true},
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{TimeFormatUnix, "abc", time.Time{}, 0, true},
{TimeFormatUnix, false, time.Time{}, 0, true},
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{TimeFormatUnixNano, false, time.Time{}, 0, true},
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199", reference, time.Second, false},
{TimeFormatAuto, "1381134199120", reference, 0, false},
{TimeFormatAuto, "1381134199120000", reference, 0, false},
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormatAuto, "abc", time.Time{}, 0, true},
{TimeFormatAuto, false, time.Time{}, 0, true},
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormat9, "08:23:19.12", reftime, 0, false},
{TimeFormat3, false, time.Time{}, 0, true},
{TimeFormat9, false, time.Time{}, 0, true},
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}

177
tx.go Normal file
View File

@@ -0,0 +1,177 @@
package sqlite3
import (
"context"
"errors"
"fmt"
"runtime"
)
type Tx struct {
c *Conn
}
// Begin starts a deferred transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
err := c.Exec(`BEGIN DEFERRED`)
if err != nil && !errors.Is(err, INTERRUPT) {
panic(err)
}
return Tx{c}
}
// BeginImmediate starts an immediate transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) BeginImmediate() (Tx, error) {
err := c.Exec(`BEGIN IMMEDIATE`)
if err != nil {
return Tx{}, err
}
return Tx{c}, nil
}
// BeginExclusive starts an exclusive transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) BeginExclusive() (Tx, error) {
err := c.Exec(`BEGIN EXCLUSIVE`)
if err != nil {
return Tx{}, err
}
return Tx{c}, nil
}
// End calls either [Tx.Commit] or [Tx.Rollback]
// depending on whether *error points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// tx := conn.Begin()
// defer tx.End(&err)
//
// // ... do work in the transaction
// }
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) End(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if tx.c.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
*errp = tx.Commit()
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
}
// Error path.
err := tx.Rollback()
if err != nil {
panic(err)
}
}
// Commit commits the transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rollsback the transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
// ROLLBACK even if the connection has been interrupted.
old := tx.c.SetInterrupt(context.Background())
defer tx.c.SetInterrupt(old)
return tx.c.Exec(`ROLLBACK`)
}
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a release func that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// // ... do work in the transaction
// }
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
if errors.Is(err, INTERRUPT) {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if c.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
}
}

94
vfs.go
View File

@@ -18,37 +18,44 @@ import (
"github.com/tetratelabs/wazero/sys"
)
func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) {
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
_, err = wasi.Instantiate(ctx)
_, err := wasi.Instantiate(ctx)
if err != nil {
return err
panic(err)
}
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("go_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("go_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("go_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("go_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("go_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("go_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("go_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("go_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("go_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("go_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("go_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("go_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("go_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("go_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("go_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("go_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("go_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("go_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("os_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("os_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("os_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("os_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("os_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("os_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("os_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("os_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("os_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("os_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("os_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("os_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("os_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("os_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("os_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("os_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("os_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("os_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("os_file_control")
_, err = env.Instantiate(ctx)
return err
if err != nil {
panic(err)
}
}
type vfsOSMethods bool
const vfsOS vfsOSMethods = false
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
_ = mod.CloseWithExitCode(ctx, exitCode)
@@ -78,7 +85,7 @@ func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uin
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.view(zByte, nByte)
mem := memory{mod}.view(zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
@@ -112,11 +119,11 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
siz := uint32(len(abs) + 1)
if siz > nFull {
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, siz)
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
@@ -145,7 +152,7 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32)
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) uint32 {
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
@@ -154,7 +161,7 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags Ac
var res uint32
switch {
case flags == ACCESS_EXISTS:
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
@@ -166,7 +173,7 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags Ac
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == ACCESS_READWRITE {
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
@@ -217,17 +224,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla
}
if flags&OPEN_DELETEONCLOSE != 0 {
deleteOnClose(file)
vfsOS.DeleteOnClose(file)
}
info, err := file.Stat()
if err != nil {
return uint32(CANTOPEN)
}
if info.IsDir() {
return uint32(CANTOPEN_ISDIR)
}
id := vfsGetOpenFileID(file, info)
id := vfsGetFileID(file)
vfsFilePtr{mod, pFile}.SetID(id).SetLock(_NO_LOCK)
if pOutFlags != 0 {
@@ -238,7 +238,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
id := vfsFilePtr{mod, pFile}.ID()
err := vfsReleaseOpenFile(id)
err := vfsCloseFile(id)
if err != nil {
return uint32(IOERR_CLOSE)
}
@@ -246,7 +246,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, iAmt)
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
n, err := file.ReadAt(buf, int64(iOfst))
@@ -263,7 +263,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, iAmt)
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
_, err := file.WriteAt(buf, int64(iOfst))
@@ -304,3 +304,15 @@ func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint3
memory{mod}.writeUint64(pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
return uint32(NOTFOUND)
}

View File

@@ -7,66 +7,35 @@ import (
"github.com/tetratelabs/wazero/api"
)
type vfsOpenFile struct {
file *os.File
info os.FileInfo
nref int
locker vfsFileLocker
}
var (
vfsOpenFiles []*vfsOpenFile
vfsOpenFiles []*os.File
vfsOpenFilesMtx sync.Mutex
)
func vfsGetOpenFileID(file *os.File, info os.FileInfo) uint32 {
func vfsGetFileID(file *os.File) uint32 {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
// Reuse an already opened file.
for id, of := range vfsOpenFiles {
if of == nil {
continue
}
if os.SameFile(info, of.info) {
of.nref++
_ = file.Close()
return uint32(id)
}
}
of := &vfsOpenFile{
file: file,
info: info,
nref: 1,
locker: vfsFileLocker{file: file},
}
// Find an empty slot.
for id, ptr := range vfsOpenFiles {
if ptr == nil {
vfsOpenFiles[id] = of
vfsOpenFiles[id] = file
return uint32(id)
}
}
// Add a new slot.
id := len(vfsOpenFiles)
vfsOpenFiles = append(vfsOpenFiles, of)
return uint32(id)
vfsOpenFiles = append(vfsOpenFiles, file)
return uint32(len(vfsOpenFiles) - 1)
}
func vfsReleaseOpenFile(id uint32) error {
func vfsCloseFile(id uint32) error {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
of := vfsOpenFiles[id]
if of.nref--; of.nref > 0 {
return nil
}
err := of.file.Close()
file := vfsOpenFiles[id]
vfsOpenFiles[id] = nil
return err
return file.Close()
}
type vfsFilePtr struct {
@@ -78,14 +47,7 @@ func (p vfsFilePtr) OSFile() *os.File {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return vfsOpenFiles[id].file
}
func (p vfsFilePtr) Locker() *vfsFileLocker {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return &vfsOpenFiles[id].locker
return vfsOpenFiles[id]
}
func (p vfsFilePtr) ID() uint32 {

View File

@@ -3,7 +3,6 @@ package sqlite3
import (
"context"
"os"
"sync"
"github.com/tetratelabs/wazero/api"
)
@@ -56,28 +55,20 @@ const (
type vfsLockState uint32
type vfsFileLocker struct {
sync.Mutex
file *os.File
state vfsLockState
shared int
}
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// SQLite never explicitly requests a pendig lock.
// Argument check. SQLite never explicitly requests a pendig lock.
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
file := ptr.OSFile()
cLock := ptr.Lock()
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
// Connection state check.
panic(assertErr())
case cLock == _NO_LOCK && eLock > _SHARED_LOCK:
// We never move from unlocked to anything higher than a shared lock.
panic(assertErr())
@@ -86,76 +77,54 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
panic(assertErr())
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
// If some other connection has a lock that precludes the requested lock, return BUSY.
if cLock != fLock.state && (eLock > _SHARED_LOCK || fLock.state >= _PENDING_LOCK) {
return uint32(BUSY)
}
// If a SHARED lock is requested, and some other connection has a SHARED or RESERVED lock,
// then increment the reference count and return OK.
if eLock == _SHARED_LOCK && (fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK) {
if cLock != _NO_LOCK || fLock.shared <= 0 {
panic(assertErr())
}
ptr.SetLock(_SHARED_LOCK)
fLock.shared++
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
// If control gets to this point, then actually go ahead and make
// operating system calls for the specified lock.
switch eLock {
case _SHARED_LOCK:
if fLock.state != _NO_LOCK || fLock.shared != 0 {
// Must be unlocked to get SHARED.
if cLock != _NO_LOCK {
panic(assertErr())
}
if rc := fLock.GetShared(); rc != _OK {
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := vfsOS.CheckPendingLock(file); locked {
return uint32(BUSY)
}
if rc := vfsOS.GetSharedLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
fLock.state = _SHARED_LOCK
fLock.shared = 1
return _OK
case _RESERVED_LOCK:
if fLock.state != _SHARED_LOCK || fLock.shared <= 0 {
// Must be SHARED to get RESERVED.
if cLock != _SHARED_LOCK {
panic(assertErr())
}
if rc := fLock.GetReserved(); rc != _OK {
if rc := vfsOS.GetReservedLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_RESERVED_LOCK)
fLock.state = _RESERVED_LOCK
return _OK
case _EXCLUSIVE_LOCK:
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK || fLock.shared <= 0 {
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if cLock <= _NO_LOCK || cLock >= _EXCLUSIVE_LOCK {
panic(assertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if fLock.state == _RESERVED_LOCK {
if rc := fLock.GetPending(); rc != _OK {
if cLock == _RESERVED_LOCK {
if rc := vfsOS.GetPendingLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_PENDING_LOCK)
fLock.state = _PENDING_LOCK
}
// We are trying for an EXCLUSIVE lock but another connection is still holding a shared lock.
if fLock.shared > 1 {
return uint32(BUSY)
}
if rc := fLock.GetExclusive(); rc != _OK {
if rc := vfsOS.GetExclusiveLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_EXCLUSIVE_LOCK)
fLock.state = _EXCLUSIVE_LOCK
return _OK
default:
@@ -164,51 +133,41 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check.
if eLock != _NO_LOCK && eLock != _SHARED_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
file := ptr.OSFile()
cLock := ptr.Lock()
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
panic(assertErr())
}
// If we don't have a more restrictive lock, do nothing.
if cLock <= eLock {
return _OK
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
switch eLock {
case _SHARED_LOCK:
if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
return _OK
if fLock.shared <= 0 {
case _NO_LOCK:
rc := vfsOS.ReleaseLock(file, cLock)
ptr.SetLock(_NO_LOCK)
return uint32(rc)
default:
panic(assertErr())
}
if cLock > _SHARED_LOCK {
if cLock != fLock.state {
panic(assertErr())
}
if eLock == _SHARED_LOCK {
if rc := fLock.Downgrade(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
fLock.state = _SHARED_LOCK
return _OK
}
}
if eLock != _NO_LOCK {
panic(assertErr())
}
// Release the connection lock and decrement the shared lock counter.
// Release the file lock only when all connections have released the lock.
ptr.SetLock(_NO_LOCK)
if fLock.shared--; fLock.shared == 0 {
fLock.state = _NO_LOCK
return uint32(fLock.Release())
}
return _OK
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
@@ -219,16 +178,9 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ui
panic(assertErr())
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
file := ptr.OSFile()
if fLock.state >= _RESERVED_LOCK {
memory{mod}.writeUint32(pResOut, 1)
return _OK
}
locked, rc := fLock.CheckReserved()
locked, rc := vfsOS.CheckReservedLock(file)
var res uint32
if locked {
res = 1
@@ -236,3 +188,28 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ui
memory{mod}.writeUint32(pResOut, res)
return uint32(rc)
}
func (vfsOSMethods) GetSharedLock(file *os.File) xErrorCode {
// Acquire the SHARED lock.
return vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
func (vfsOSMethods) GetReservedLock(file *os.File) xErrorCode {
// Acquire the RESERVED lock.
return vfsOS.writeLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) GetPendingLock(file *os.File) xErrorCode {
// Acquire the PENDING lock.
return vfsOS.writeLock(file, _PENDING_BYTE, 1)
}
func (vfsOSMethods) CheckReservedLock(file *os.File) (bool, xErrorCode) {
// Test the RESERVED lock.
return vfsOS.checkLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) CheckPendingLock(file *os.File) (bool, xErrorCode) {
// Test the PENDING lock.
return vfsOS.checkLock(file, _PENDING_BYTE, 1)
}

View File

@@ -3,29 +3,28 @@ package sqlite3
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
)
func Test_vfsLock(t *testing.T) {
// Other OSes lack open file descriptors locks.
switch runtime.GOOS {
case "linux", "darwin", "solaris", "windows":
//
case "linux", "darwin", "illumos", "windows":
break
default:
t.Skip()
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.CreateTemp("", "sqlite3-")
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
name := file1.Name()
defer os.RemoveAll(name)
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
@@ -33,26 +32,14 @@ func Test_vfsLock(t *testing.T) {
}
defer file2.Close()
// Bypass open file reuse.
vfsOpenFiles = append(vfsOpenFiles, &vfsOpenFile{
file: file1,
nref: 1,
locker: vfsFileLocker{file: file1},
}, &vfsOpenFile{
file: file2,
nref: 1,
locker: vfsFileLocker{file: file2},
})
mem := newMemory(128)
mem.writeUint32(4+4, 0)
mem.writeUint32(16+4, 1)
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mem := newMemory(128)
vfsFilePtr{mem.mod, pFile1}.SetID(vfsGetFileID(file1)).SetLock(_NO_LOCK)
vfsFilePtr{mem.mod, pFile2}.SetID(vfsGetFileID(file2)).SetLock(_NO_LOCK)
rc := vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
@@ -110,11 +97,27 @@ func Test_vfsLock(t *testing.T) {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsUnlock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)

View File

@@ -7,6 +7,7 @@ import (
"io/fs"
"os"
"path/filepath"
"syscall"
"testing"
"time"
@@ -136,12 +137,12 @@ func Test_vfsFullPathname(t *testing.T) {
}
func Test_vfsDelete(t *testing.T) {
file, err := os.CreateTemp("", "sqlite3-")
name := filepath.Join(t.TempDir(), "test.db")
file, err := os.Create(name)
if err != nil {
t.Fatal(err)
}
name := file.Name()
defer os.RemoveAll(name)
file.Close()
mem := newMemory(128 + _MAX_PATHNAME)
@@ -163,16 +164,21 @@ func Test_vfsDelete(t *testing.T) {
}
func Test_vfsAccess(t *testing.T) {
dir, err := os.MkdirTemp("", "sqlite3-")
if err != nil {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
f.Close()
}
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_EXISTS, 4)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -180,13 +186,22 @@ func Test_vfsAccess(t *testing.T) {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_READWRITE, 4)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
t.Error("can't access directory")
}
mem.writeString(8, file)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {

View File

@@ -8,68 +8,38 @@ import (
"syscall"
)
func deleteOnClose(f *os.File) {
_ = os.Remove(f.Name())
func (vfsOSMethods) DeleteOnClose(file *os.File) {
_ = os.Remove(file.Name())
}
func (l *vfsFileLocker) GetShared() xErrorCode {
// A PENDING lock is needed before acquiring a SHARED lock.
if rc := l.readLock(_PENDING_BYTE, 1); rc != _OK {
return rc
}
// Acquire the SHARED lock.
rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE)
// Drop the temporary PENDING lock.
if rc2 := l.unlock(_PENDING_BYTE, 1); rc == _OK {
return rc2
}
return rc
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
// Acquire the RESERVED lock.
return l.writeLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) GetPending() xErrorCode {
// Acquire the PENDING lock.
return l.writeLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) GetExclusive() xErrorCode {
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Acquire the EXCLUSIVE lock.
return l.writeLock(_SHARED_FIRST, _SHARED_SIZE)
return vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
// Downgrade to a SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return l.unlock(_PENDING_BYTE, 2)
return vfsOS.unlock(file, _PENDING_BYTE, 2)
}
func (l *vfsFileLocker) Release() xErrorCode {
func (vfsOSMethods) ReleaseLock(file *os.File, _ vfsLockState) xErrorCode {
// Release all locks.
return l.unlock(0, 0)
return vfsOS.unlock(file, 0, 0)
}
func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
// Test the RESERVED lock.
return l.checkLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) unlock(start, len int64) xErrorCode {
err := l.fcntlSetLock(&syscall.Flock_t{
func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode {
err := vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_UNLCK,
Start: start,
Len: len,
@@ -80,79 +50,84 @@ func (l *vfsFileLocker) unlock(start, len int64) xErrorCode {
return _OK
}
func (l *vfsFileLocker) readLock(start, len int64) xErrorCode {
return l.errorCode(l.fcntlSetLock(&syscall.Flock_t{
func (vfsOSMethods) readLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}), IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len int64) xErrorCode {
return l.errorCode(l.fcntlSetLock(&syscall.Flock_t{
func (vfsOSMethods) writeLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_WRLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}
func (l *vfsFileLocker) checkLock(start, len int64) (bool, xErrorCode) {
func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) {
lock := syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}
if l.fcntlGetLock(&lock) != nil {
if vfsOS.fcntlGetLock(file, &lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != syscall.F_UNLCK, _OK
}
func (l *vfsFileLocker) fcntlGetLock(lock *syscall.Flock_t) error {
F_GETLK := syscall.F_GETLK
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *syscall.Flock_t) error {
var F_OFD_GETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_GETLK = 36 // F_OFD_GETLK
F_OFD_GETLK = 36 // F_OFD_GETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_GETLK = 92 // F_OFD_GETLK
case "solaris":
F_OFD_GETLK = 92 // F_OFD_GETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_GETLK = 47 // F_OFD_GETLK
F_OFD_GETLK = 47 // F_OFD_GETLK
default:
return notImplErr
}
return syscall.FcntlFlock(l.file.Fd(), F_GETLK, lock)
return syscall.FcntlFlock(file.Fd(), F_OFD_GETLK, lock)
}
func (l *vfsFileLocker) fcntlSetLock(lock *syscall.Flock_t) error {
F_SETLK := syscall.F_SETLK
func (vfsOSMethods) fcntlSetLock(file *os.File, lock *syscall.Flock_t) error {
var F_OFD_SETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_SETLK = 37 // F_OFD_SETLK
F_OFD_SETLK = 37 // F_OFD_SETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_SETLK = 90 // F_OFD_SETLK
case "solaris":
F_OFD_SETLK = 90 // F_OFD_SETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_SETLK = 48 // F_OFD_SETLK
F_OFD_SETLK = 48 // F_OFD_SETLK
default:
return notImplErr
}
return syscall.FcntlFlock(l.file.Fd(), F_SETLK, lock)
return syscall.FcntlFlock(file.Fd(), F_OFD_SETLK, lock)
}
func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case syscall.EACCES:
case syscall.EAGAIN:
case syscall.EBUSY:
case syscall.EINTR:
case syscall.ENOLCK:
case syscall.EDEADLK:
case syscall.ETIMEDOUT:
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
return xErrorCode(PERM)

View File

@@ -7,84 +7,61 @@ import (
"golang.org/x/sys/windows"
)
func deleteOnClose(f *os.File) {}
func (vfsOSMethods) DeleteOnClose(file *os.File) {}
func (l *vfsFileLocker) GetShared() xErrorCode {
// A PENDING lock is needed before acquiring a SHARED lock.
if rc := l.readLock(_PENDING_BYTE, 1); rc != _OK {
return rc
}
// Acquire the SHARED lock.
rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE)
// Drop the temporary PENDING lock.
if rc2 := l.unlock(_PENDING_BYTE, 1); rc == _OK {
return rc2
}
return rc
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
// Acquire the RESERVED lock.
return l.writeLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) GetPending() xErrorCode {
// Acquire the PENDING lock.
return l.writeLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) GetExclusive() xErrorCode {
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := l.writeLock(_SHARED_FIRST, _SHARED_SIZE)
rc := vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc != _OK {
l.readLock(_SHARED_FIRST, _SHARED_SIZE)
vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
return rc
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
// Reacquire the SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
l.unlock(_RESERVED_BYTE, 1)
l.unlock(_PENDING_BYTE, 1)
return _OK
}
func (l *vfsFileLocker) Release() xErrorCode {
// Release all locks.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
l.unlock(_RESERVED_BYTE, 1)
l.unlock(_PENDING_BYTE, 1)
return _OK
}
func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
// Test the RESERVED lock.
rc := l.readLock(_RESERVED_BYTE, 1)
if rc == _OK {
l.unlock(_RESERVED_BYTE, 1)
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
return rc != _OK, _OK
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (l *vfsFileLocker) unlock(start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(l.file.Fd()),
func (vfsOSMethods) ReleaseLock(file *os.File, state vfsLockState) xErrorCode {
// Release all locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _SHARED_LOCK {
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) unlock(file *os.File, start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err != nil {
return IOERR_UNLOCK
@@ -92,21 +69,29 @@ func (l *vfsFileLocker) unlock(start, len uint32) xErrorCode {
return _OK
}
func (l *vfsFileLocker) readLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
func (vfsOSMethods) readLock(file *os.File, start, len uint32) xErrorCode {
return vfsOS.lockErrorCode(windows.LockFileEx(windows.Handle(file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
func (vfsOSMethods) writeLock(file *os.File, start, len uint32) xErrorCode {
return vfsOS.lockErrorCode(windows.LockFileEx(windows.Handle(file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
}
func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
func (vfsOSMethods) checkLock(file *os.File, start, len uint32) (bool, xErrorCode) {
rc := vfsOS.readLock(file, start, len)
if rc == _OK {
vfsOS.unlock(file, start, len)
}
return rc != _OK, _OK
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}