Compare commits

...

65 Commits

Author SHA1 Message Date
Nuno Cruces
e4efb20c71 Generate coverage chart. 2023-03-18 03:51:05 +00:00
Nuno Cruces
2c9459d907 Add SQLite speedtest1. 2023-03-18 03:03:11 +00:00
Nuno Cruces
d0875e5fab Lock timeouts. 2023-03-18 01:13:31 +00:00
Nuno Cruces
15dec13f15 FCNTL_SIZE_HINT, refactor. 2023-03-17 17:13:03 +00:00
Nuno Cruces
f38e36109a FCNTL_HAS_MOVED. 2023-03-17 14:11:09 +00:00
Nuno Cruces
4cb65ccbd9 xFileControl, xDeviceCharacteristics, PSOW. 2023-03-17 13:39:19 +00:00
Nuno Cruces
f789c2fb8b OPEN_NOFOLLOW. 2023-03-16 12:27:44 +00:00
Nuno Cruces
c6a2617dfc Locking fixes. 2023-03-16 02:52:22 +00:00
Nuno Cruces
6fc0afcd12 Towards lock timeouts. 2023-03-15 13:58:16 +00:00
Nuno Cruces
77088962f5 SQLite 3.41.1. 2023-03-15 13:29:09 +00:00
Nuno Cruces
71da34861b Fix time collation. 2023-03-13 04:19:58 +00:00
Nuno Cruces
56e8281bdb Time collation tests. 2023-03-10 16:42:20 +00:00
Nuno Cruces
f61d430e65 Documentation. 2023-03-10 16:26:19 +00:00
Nuno Cruces
dbaed53b9a Sync and delete improvements. 2023-03-10 14:17:02 +00:00
Nuno Cruces
8b1bfd04e3 Simplify windows hacks. 2023-03-10 10:43:02 +00:00
Nuno Cruces
11c1687146 Time collation. 2023-03-09 14:42:29 +00:00
Nuno Cruces
94c43a8685 Use access syscall. 2023-03-09 01:59:46 +00:00
Nuno Cruces
a25159a070 Fix sharing violation. 2023-03-09 01:23:52 +00:00
Nuno Cruces
e007e9b060 Windows fixes. 2023-03-08 20:10:46 +00:00
Nuno Cruces
66a730893f Fix readonly transaction rollback. 2023-03-08 18:07:21 +00:00
Nuno Cruces
926adeb3f5 Remove MustPrepare. 2023-03-08 17:39:41 +00:00
Nuno Cruces
677f51bec1 Savepoint API. 2023-03-08 17:39:23 +00:00
Nuno Cruces
5d6f92b733 Documentation, tests, tweaks. 2023-03-08 13:29:33 +00:00
Nuno Cruces
f5747f19fb Tests. 2023-03-07 14:19:22 +00:00
Nuno Cruces
dfcdbf9c4c Online backup. 2023-03-07 12:15:29 +00:00
Nuno Cruces
ad1e8f4b0e Refactor. 2023-03-07 10:47:55 +00:00
Nuno Cruces
8f29882671 Pass mptest crash. 2023-03-07 04:37:55 +00:00
Nuno Cruces
6c96a019e6 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
d291738b81 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
c1263d4f33 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
1ebdc1aa93 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
4dd10f071a Towards shared modules: backup. 2023-03-07 04:37:55 +00:00
Nuno Cruces
7dbddfa5c0 Towards shared modules. 2023-03-07 04:37:55 +00:00
dependabot[bot]
ce5e035801 Bump golang.org/x/sys from 0.5.0 to 0.6.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.5.0 to 0.6.0.
- [Release notes](https://github.com/golang/sys/releases)
- [Commits](https://github.com/golang/sys/compare/v0.5.0...v0.6.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-03-07 04:11:45 +00:00
Nuno Cruces
8bb8367a36 Refactor mptest. 2023-03-06 14:27:49 +00:00
Nuno Cruces
9f59b3d0ec Pass mptest multiwrite on Windows. 2023-03-05 14:33:51 +00:00
Nuno Cruces
5f893b5459 Add SQLite mptest. 2023-03-05 12:20:02 +00:00
Nuno Cruces
35b1a97b88 Use finalizers to detect unclosed connections. 2023-03-03 14:50:55 +00:00
Nuno Cruces
416c3863a0 Documentation, tests, dependencies. 2023-03-01 23:47:24 +00:00
Nuno Cruces
dbc400eb15 Refactor native code. 2023-03-01 13:12:32 +00:00
Nuno Cruces
35265271aa Rename. 2023-03-01 11:18:25 +00:00
Nuno Cruces
c7165a2e56 Documentation. 2023-03-01 10:34:39 +00:00
Nuno Cruces
e64bffa520 Pragmas. 2023-02-28 16:03:31 +00:00
Nuno Cruces
54046b6adc Documentation. 2023-02-28 16:02:13 +00:00
Nuno Cruces
1b3823483f Incremental blobs. 2023-02-27 13:45:32 +00:00
Nuno Cruces
ce6d0627b2 Tests. 2023-02-27 12:07:48 +00:00
Nuno Cruces
dd30215702 Incremental blobs. 2023-02-27 04:08:55 +00:00
Nuno Cruces
21aff4c9f5 Towards incremental blobs. 2023-02-27 03:20:23 +00:00
Nuno Cruces
b30f127547 WAL mode, extensions. 2023-02-26 04:49:10 +00:00
Nuno Cruces
6509e5deb2 Transactions. 2023-02-26 03:22:08 +00:00
Nuno Cruces
125b8053f8 Fix readonly transactions. 2023-02-25 15:34:24 +00:00
Nuno Cruces
1e4a246d2f Error handling. 2023-02-25 15:11:07 +00:00
Nuno Cruces
e6cd0aaf87 MustPrepare. 2023-02-25 01:29:46 +00:00
Nuno Cruces
c1472a48b0 Tests. 2023-02-25 00:50:03 +00:00
Nuno Cruces
a69ab1ebe3 Fix data race. 2023-02-24 15:19:57 +00:00
Nuno Cruces
1190c21684 Refactor. 2023-02-24 15:06:19 +00:00
Nuno Cruces
8c28c3a6f4 Interrupt API. 2023-02-24 14:56:49 +00:00
Nuno Cruces
0146496036 Nested transactions. 2023-02-24 14:31:41 +00:00
Nuno Cruces
fcd33d2f0f Time improvements. 2023-02-24 11:09:30 +00:00
Nuno Cruces
627df5db0f No sandbox. 2023-02-23 14:16:37 +00:00
Nuno Cruces
1ed62d300d Require OFD locks. 2023-02-23 13:29:51 +00:00
Nuno Cruces
5b2451c3ad Default sector size. 2023-02-23 03:22:39 +00:00
Nuno Cruces
d52e0371eb Only reuse main db files. 2023-02-23 02:22:57 +00:00
Nuno Cruces
75f2644b0e SQLite 3.41.0. 2023-02-22 20:08:50 +00:00
Nuno Cruces
71ae26e5c9 Documentation. 2023-02-22 17:51:30 +00:00
84 changed files with 5458 additions and 1779 deletions

View File

@@ -15,6 +15,8 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
lfs: 'true'
- name: Set up Go
uses: actions/setup-go@v3
@@ -34,6 +36,8 @@ jobs:
- name: Update coverage report
uses: ncruces/go-coverage-report@main
with:
chart: 'true'
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'

5
.gitignore vendored
View File

@@ -13,7 +13,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
tools
# Project
sqlite3/sqlite3*
tools

View File

@@ -2,25 +2,83 @@
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
[![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report)
⚠️ CAUTION ⚠️
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
This is a WIP.\
DO NOT USE with data you care about.
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
Roadmap:
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] design a simple, nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on Windows/Unix
- [ ] shared memory, compatible with SQLite on Windows/Unix
- needed for improved WAL mode
- [ ] advanced features
- [ ] incremental BLOB I/O
- [ ] online backup
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
with a pure Go implementation.
This has numerous benefits, but also comes with some drawbacks.
#### Write-Ahead Logging
Because WASM does not support shared memory,
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
To work around this limitation, SQLite is compiled with
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
making `EXCLUSIVE` the default locking mode.
For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Open File Description Locks
On Unix, this module uses [OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
POSIX advisory locks, which SQLite uses, are [broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
OFD locks are fully compatible with process-associated POSIX advisory locks,
and are supported on Linux and macOS.
As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.org/uri.html).
#### Testing
The pure Go VFS is stress tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Roadmap
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [ ] snapshots
- [ ] session extension
- [ ] snapshot
- [ ] resumable bulk update
- [ ] shared-cache mode
- [ ] unlock-notify
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] in-memory VFS, wrapping a [`bytes.Buffer`](https://pkg.go.dev/bytes#Buffer)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom VFS API
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

112
api.go
View File

@@ -1,112 +0,0 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"github.com/tetratelabs/wazero/api"
)
func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
getFun := func(name string) api.Function {
f := module.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getPtr := func(name string) uint32 {
global := module.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return memory{module}.readUint32(uint32(global.Get()))
}
c := Conn{
ctx: ctx,
mem: memory{module},
api: sqliteAPI{
malloc: getFun("malloc"),
free: getFun("free"),
destructor: uint64(getPtr("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
interrupt: getFun("sqlite3_interrupt"),
},
}
if err != nil {
return nil, err
}
return &c, nil
}
type sqliteAPI struct {
malloc api.Function
free api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
lastRowid api.Function
changes api.Function
interrupt api.Function
}

134
backup.go Normal file
View File

@@ -0,0 +1,134 @@
package sqlite3
// Backup is a handle to an open BLOB.
//
// https://www.sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle uint32
otherc uint32
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup calls [Open] to open the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
// https://www.sqlite.org/backup.html
func (src *Conn) Backup(srcDB, dstURI string) error {
b, err := src.BackupInit(srcDB, dstURI)
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore calls [Open] to open the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://www.sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
if err != nil {
return err
}
b, err := dst.backupInit(dst.handle, dstDB, src, "main")
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit calls [Open] to open the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupinit
func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
dst, err := src.openDB(dstURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
if err != nil {
return nil, err
}
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
defer c.arena.reset()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
other := dst
if c.handle == dst {
other = src
}
r := c.call(c.api.backupInit,
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r[0] == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r[0], dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r[0]),
}, nil
}
// Close finishes a backup operation.
//
// It is safe to close a nil, zero or closed Backup.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupfinish
func (b *Backup) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r[0])
}
// Step copies up to nPage pages between the source and destination databases.
// If nPage is negative, all remaining source pages are copied.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
if r[0] == _DONE {
return true, nil
}
return false, b.c.error(r[0])
}
// Remaining returns the number of pages still to be backed up
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
return int(r[0])
}
// PageCount returns the total number of pages in the source database
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
return int(r[0])
}

155
blob.go
View File

@@ -1,6 +1,159 @@
package sqlite3
import "io"
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database.sql.DB.Exec] and similar methods.
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64
// Blob is a handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
//
// https://www.sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
handle uint32
bytes int64
offset int64
}
var _ io.ReadWriteSeeker = &Blob{}
// OpenBlob opens a BLOB for incremental I/O.
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
tablePtr := c.arena.string(table)
columnPtr := c.arena.string(column)
var flags uint64
if write {
flags = 1
}
r := c.call(c.api.blobOpen, uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
if err := c.error(r[0]); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = c.mem.readUint32(blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
return &blob, nil
}
// Close closes a BLOB handle.
//
// It is safe to close a nil, zero or closed Blob.
//
// https://www.sqlite.org/c3ref/blob_close.html
func (b *Blob) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
b.handle = 0
return b.c.error(r[0])
}
// Size returns the size of the BLOB in bytes.
//
// https://www.sqlite.org/c3ref/blob_bytes.html
func (b *Blob) Size() int64 {
return b.bytes
}
// Read implements the [io.Reader] interface.
//
// https://www.sqlite.org/c3ref/blob_read.html
func (b *Blob) Read(p []byte) (n int, err error) {
if b.offset >= b.bytes {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
mem := b.c.mem.view(ptr, uint64(want))
copy(p, mem)
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
return int(want), err
}
// Write implements the [io.Writer] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
offset := b.offset
if offset > b.bytes {
offset = b.bytes
}
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
b.offset += int64(len(p))
return len(p), nil
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, whenceErr
case io.SeekStart:
break
case io.SeekCurrent:
offset += b.offset
case io.SeekEnd:
offset += b.bytes
}
if offset < 0 {
return 0, offsetErr
}
b.offset = offset
return offset, nil
}
// Reopen moves a BLOB handle to a new row of the same database table.
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
b.offset = 0
return b.c.error(r[0])
}

View File

@@ -1,58 +0,0 @@
package sqlite3
import (
"context"
"os"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite.
var (
Binary []byte // Binary to load.
Path string // Path to load the binary from.
)
var sqlite3 sqlite3Runtime
type sqlite3Runtime struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
}
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.once.Do(func() { s.compileModule(ctx) })
if s.err != nil {
return nil, s.err
}
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
return
}
}
if bin == nil {
s.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

431
conn.go
View File

@@ -2,64 +2,118 @@ package sqlite3
import (
"context"
"math"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
)
// Conn is a database connection handle.
// A Conn is not safe for concurrent use by multiple goroutines.
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
ctx context.Context
api sqliteAPI
mem memory
handle uint32
*module
arena arena
pending *Stmt
waiter chan struct{}
done <-chan struct{}
handle uint32
arena arena
interrupt context.Context
waiter chan struct{}
pending *Stmt
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
func Open(filename string) (conn *Conn, err error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE)
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
module, err := sqlite3.instantiateModule(ctx)
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE
}
return newConn(filename, flags)
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
module.Close(ctx)
mod.close()
} else {
runtime.SetFinalizer(conn, finalizer[Conn](3))
}
}()
c, err := newConn(ctx, module)
c := &Conn{module: mod}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
}
c.arena = c.newArena(1024)
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
defer c.arena.reset()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
if err != nil {
panic(err)
flags |= OPEN_EXRESCODE
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := c.mem.readUint32(connPtr)
if err := c.module.error(r[0], handle); err != nil {
c.closeDB(handle)
return 0, err
}
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
return nil, err
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
query, _ := url.ParseQuery(after)
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
}
}
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
c.closeDB(handle)
return 0, err
}
}
c.call(c.api.timeCollation, uint64(handle))
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r[0], handle); err != nil {
panic(err)
}
return c, nil
}
// Close closes the database connection.
@@ -68,7 +122,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
// open blob handles, and/or unfinished backup objects,
// Close will leave the database connection open and return [BUSY].
//
// It is safe to close a nil, zero or closed connection.
// It is safe to close a nil, zero or closed Conn.
//
// https://www.sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
@@ -76,82 +130,18 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(nil)
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
return err
}
c.handle = 0
return c.mem.mod.Close(c.ctx)
}
// SetInterrupt interrupts a long-running query when done is closed.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until done is reset by another call to SetInterrupt.
//
// Typically, done is provided by [context.Context.Done]:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx.Done())
// defer cancel()
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
old = c.done
c.done = done
if done == nil {
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-done: // Done was closed.
// This is safe to call from a goroutine
// because it doesn't touch the C stack.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
runtime.SetFinalizer(c, nil)
return c.module.close()
}
// Exec is a convenience function that allows an application to run
@@ -159,13 +149,11 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
//
// https://www.sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
panic(err)
}
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r[0])
}
@@ -190,12 +178,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
r := c.call(c.api.prepare, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
panic(err)
}
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
@@ -211,16 +196,21 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://www.sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call(c.api.autocommit, uint64(c.handle))
return r[0] != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
func (c *Conn) LastInsertRowID() int64 {
r := c.call(c.api.lastRowid, uint64(c.handle))
return int64(r[0])
}
// Changes returns the number of rows modified, inserted or deleted
@@ -228,125 +218,118 @@ func (c *Conn) LastInsertRowID() uint64 {
// on the database connection.
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
func (c *Conn) Changes() int64 {
r := c.call(c.api.changes, uint64(c.handle))
return int64(r[0])
}
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until the context is reset by another call to SetInterrupt.
//
// To associate a timeout with a connection:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx)
// defer cancel()
//
// SetInterrupt returns the old context assigned to the connection.
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
return r[0]
// Reset the pending statement.
if c.pending != nil {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
buf := c.mem.view(c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
buf := c.mem.view(c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil {
return nil, err
}
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas, stmt.Close()
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
}
var r []uint64
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
return c.module.error(rc, c.handle, sql...)
}
func (c *Conn) free(ptr uint32) {
if ptr == 0 {
return
}
_, err := c.api.free.Call(c.ctx, uint64(ptr))
if err != nil {
panic(err)
}
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints] and [incremental BLOB I/O].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
func (c *Conn) new(size uint32) uint32 {
r, err := c.api.malloc.Call(c.ctx, uint64(size))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
}
func (c *Conn) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := c.new(uint32(len(b)))
c.mem.writeBytes(ptr, b)
return ptr
}
func (c *Conn) newString(s string) uint32 {
ptr := c.new(uint32(len(s) + 1))
c.mem.writeString(ptr, s)
return ptr
}
func (c *Conn) newArena(size uint32) arena {
return arena{
c: c,
size: size,
base: c.new(size),
}
}
type arena struct {
c *Conn
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.c.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint32) uint32 {
if a.next+size <= a.size {
ptr := a.base + a.next
a.next += size
return ptr
}
ptr := a.c.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint32(len(s) + 1))
a.c.mem.writeString(ptr, s)
return ptr
Savepoint() Savepoint
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -7,11 +7,15 @@ const (
_ROW = 100 /* sqlite3_step() has another row ready */
_DONE = 101 /* sqlite3_step() has finished executing */
_OK_SYMLINK = (_OK | (2 << 8)) /* internal use only */
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_MAX_ALLOCATION_SIZE = 0x7ffffeff
ptrlen = 4
)
@@ -131,6 +135,7 @@ const (
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
@@ -165,14 +170,6 @@ const (
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_prepare_normalize.html
@@ -214,3 +211,65 @@ func (t Datatype) String() string {
}
return strconv.FormatUint(uint64(t), 10)
}
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
type _SyncFlag uint32
const (
_SYNC_NORMAL _SyncFlag = 0x00002
_SYNC_FULL _SyncFlag = 0x00003
_SYNC_DATAONLY _SyncFlag = 0x00010
)
type _FcntlOpcode uint32
const (
_FCNTL_LOCKSTATE = 1
_FCNTL_GET_LOCKPROXYFILE = 2
_FCNTL_SET_LOCKPROXYFILE = 3
_FCNTL_LAST_ERRNO = 4
_FCNTL_SIZE_HINT = 5
_FCNTL_CHUNK_SIZE = 6
_FCNTL_FILE_POINTER = 7
_FCNTL_SYNC_OMITTED = 8
_FCNTL_WIN32_AV_RETRY = 9
_FCNTL_PERSIST_WAL = 10
_FCNTL_OVERWRITE = 11
_FCNTL_VFSNAME = 12
_FCNTL_POWERSAFE_OVERWRITE = 13
_FCNTL_PRAGMA = 14
_FCNTL_BUSYHANDLER = 15
_FCNTL_TEMPFILENAME = 16
_FCNTL_MMAP_SIZE = 18
_FCNTL_TRACE = 19
_FCNTL_HAS_MOVED = 20
_FCNTL_SYNC = 21
_FCNTL_COMMIT_PHASETWO = 22
_FCNTL_WIN32_SET_HANDLE = 23
_FCNTL_WAL_BLOCK = 24
_FCNTL_ZIPVFS = 25
_FCNTL_RBU = 26
_FCNTL_VFS_POINTER = 27
_FCNTL_JOURNAL_POINTER = 28
_FCNTL_WIN32_GET_HANDLE = 29
_FCNTL_PDB = 30
_FCNTL_BEGIN_ATOMIC_WRITE = 31
_FCNTL_COMMIT_ATOMIC_WRITE = 32
_FCNTL_ROLLBACK_ATOMIC_WRITE = 33
_FCNTL_LOCK_TIMEOUT = 34
_FCNTL_DATA_VERSION = 35
_FCNTL_SIZE_LIMIT = 36
_FCNTL_CKPT_DONE = 37
_FCNTL_RESERVE_BYTES = 38
_FCNTL_CKPT_START = 39
_FCNTL_EXTERNAL_READER = 40
_FCNTL_CKSM_FILE = 41
_FCNTL_RESET_CACHE = 42
)

View File

@@ -1,4 +1,27 @@
// Package driver provides a database/sql driver for SQLite.
//
// Importing package driver registers a [database/sql] driver named "sqlite3".
// You may also need to import package embed.
//
// import _ "github.com/ncruces/go-sqlite3/driver"
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// The data source name for "sqlite3" databases can be a filename or a "file:" [URI].
//
// The [TRANSACTION] mode can be specified using "_txlock":
//
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// If no PRAGMAs are specifed, a busy timeout of 1 minute
// and normal locking mode are used.
//
// [URI]: https://www.sqlite.org/uri.html
// [PRAGMA]: https://www.sqlite.org/pragma.html
// [TRANSACTION]: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
package driver
import (
@@ -20,102 +43,112 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
func (sqlite) Open(name string) (_ driver.Conn, err error) {
c, err := sqlite3.Open(name)
if err != nil {
return nil, err
}
var txBegin string
var pragmas strings.Builder
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
pragmas = query["_pragma"]
}
}
if pragmas.Len() == 0 {
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
if len(pragmas) == 0 {
err := c.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
if err != nil {
c.Close()
return nil, err
}
}
err = c.Exec(pragmas.String())
if err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
return conn{
conn: c,
txBegin: txBegin,
pragmas: pragmas.String(),
}, nil
}
type conn struct {
conn *sqlite3.Conn
pragmas string
txBegin string
txReadOnly bool
txCommit string
txRollback string
}
var (
// Ensure these interfaces are implemented:
_ driver.Validator = conn{}
_ driver.SessionResetter = conn{}
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.Validator = conn{}
_ sqlite3.DriverConn = conn{}
)
func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() bool {
// Pool only normal locking mode connections.
stmt, _, err := c.conn.Prepare(`PRAGMA locking_mode`)
if err != nil {
return false
}
defer stmt.Close()
return stmt.Step() && stmt.ColumnText(0) == "normal"
}
func (c conn) ResetSession(ctx context.Context) error {
return c.conn.Exec(c.pragmas)
func (c conn) IsValid() (valid bool) {
r, err := c.conn.Pragma("locking_mode")
return err == nil && len(r) == 1 && r[0] == "normal"
}
func (c conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
query_only, err := c.conn.Pragma("query_only")
if err != nil {
return nil, err
}
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + query_only[0]
c.txRollback = c.txCommit
}
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
case driver.IsolationLevel(sql.LevelReadUncommitted):
read_uncommitted, err := c.conn.Pragma("read_uncommitted")
if err != nil {
return nil, err
}
txBegin += `; PRAGMA read_uncommitted=on`
c.txCommit += `; PRAGMA read_uncommitted=` + read_uncommitted[0]
c.txRollback += `; PRAGMA read_uncommitted=` + read_uncommitted[0]
}
txBegin := c.txBegin
if opts.ReadOnly {
txBegin = `
BEGIN deferred;
PRAGMA query_only=on;
`
}
c.txReadOnly = opts.ReadOnly
err := c.conn.Exec(txBegin)
if err != nil {
return nil, err
@@ -124,18 +157,15 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er
}
func (c conn) Commit() error {
if c.txReadOnly {
return c.Rollback()
}
err := c.conn.Exec(`COMMIT`)
if err != nil {
err := c.conn.Exec(c.txCommit)
if err != nil && !c.conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
return c.conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
@@ -159,14 +189,18 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
return stmt{s, c.conn}, nil
}
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return c.Prepare(query)
}
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
ch := c.conn.SetInterrupt(ctx.Done())
defer c.conn.SetInterrupt(ch)
old := c.conn.SetInterrupt(ctx)
defer c.conn.SetInterrupt(old)
err := c.conn.Exec(query)
if err != nil {
@@ -174,11 +208,19 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
}
return result{
int64(c.conn.LastInsertRowID()),
int64(c.conn.Changes()),
c.conn.LastInsertRowID(),
c.conn.Changes(),
}, nil
}
func (c conn) Savepoint() sqlite3.Savepoint {
return c.conn.Savepoint()
}
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
return c.conn.OpenBlob(db, table, column, row, write)
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
@@ -325,8 +367,8 @@ func (r rows) Columns() []string {
}
func (r rows) Next(dest []driver.Value) error {
ch := r.conn.SetInterrupt(r.ctx.Done())
defer r.conn.SetInterrupt(ch)
old := r.conn.SetInterrupt(r.ctx)
defer r.conn.SetInterrupt(old)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
@@ -342,7 +384,7 @@ func (r rows) Next(dest []driver.Value) error {
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeDate(r.stmt.ColumnText(i))
dest[i] = maybeTime(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)

View File

@@ -1,4 +1,3 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
@@ -15,6 +14,8 @@ import (
)
func Test_Open_dir(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ".")
if err != nil {
t.Fatal(err)
@@ -25,19 +26,14 @@ func Test_Open_dir(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func Test_Open_pragma(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)")
if err != nil {
t.Fatal(err)
@@ -55,6 +51,8 @@ func Test_Open_pragma(t *testing.T) {
}
func Test_Open_pragma_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000")
if err != nil {
t.Fatal(err)
@@ -73,13 +71,15 @@ func Test_Open_pragma_invalid(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}
func Test_Open_txLock(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "test.db")+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
@@ -95,20 +95,13 @@ func Test_Open_txLock(t *testing.T) {
if err == nil {
t.Error("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
if !errors.Is(err, sqlite3.BUSY) {
t.Errorf("got %v, want sqlite3.BUSY", err)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked` {
t.Error("got message: ", got)
}
err = tx1.Commit()
if err != nil {
@@ -117,6 +110,8 @@ func Test_Open_txLock(t *testing.T) {
}
func Test_Open_txLock_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err != nil {
t.Fatal(err)
@@ -128,15 +123,19 @@ func Test_Open_txLock_invalid(t *testing.T) {
t.Fatal("want error")
}
if got := err.Error(); got != `sqlite3: invalid _txlock: xclusive` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}
func Test_BeginTx(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
@@ -147,6 +146,16 @@ func Test_BeginTx(t *testing.T) {
t.Error("want isolationErr")
}
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted})
if err != nil {
t.Fatal(err)
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
@@ -161,15 +170,8 @@ func Test_BeginTx(t *testing.T) {
if err == nil {
t.Error("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.READONLY {
t.Errorf("got %d, want sqlite3.READONLY", rc)
}
if got := err.Error(); got != `sqlite3: attempt to write a readonly database` {
t.Error("got message: ", got)
if !errors.Is(err, sqlite3.READONLY) {
t.Errorf("got %v, want sqlite3.READONLY", err)
}
err = tx2.Commit()
@@ -184,6 +186,8 @@ func Test_BeginTx(t *testing.T) {
}
func Test_Prepare(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
@@ -208,7 +212,7 @@ func Test_Prepare(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; SELECT`)
@@ -222,7 +226,7 @@ func Test_Prepare(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
@@ -232,6 +236,8 @@ func Test_Prepare(t *testing.T) {
}
func Test_QueryRow_named(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -282,6 +288,8 @@ func Test_QueryRow_named(t *testing.T) {
}
func Test_QueryRow_blob_null(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
@@ -310,39 +318,3 @@ func Test_QueryRow_blob_null(t *testing.T) {
}
}
}
func Test_ZeroBlob(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
_, err = conn.ExecContext(ctx, `INSERT INTO test(col) VALUES(?)`, sqlite3.ZeroBlob(4))
if err != nil {
t.Fatal(err)
}
var got []byte
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}

View File

@@ -28,10 +28,11 @@ func Example() {
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
defer db.Close()
err = createAlbumsTable()
// Create a table with some data in it.
err = albumsSetup()
if err != nil {
log.Fatal(err)
}
@@ -58,14 +59,13 @@ func Example() {
log.Fatal(err)
}
fmt.Printf("ID of added album: %v\n", albID)
// Output:
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
// Album found: {2 Giant Steps John Coltrane 63.99}
// ID of added album: 5
}
func createAlbumsTable() error {
func albumsSetup() error {
_, err := db.Exec(`
DROP TABLE IF EXISTS album;
CREATE TABLE album (

View File

@@ -8,8 +8,8 @@ import (
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database.sql] will recover the same string.
func maybeDate(text string) driver.Value {
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {

View File

@@ -5,7 +5,8 @@ import (
"time"
)
func Fuzz_maybeDate(f *testing.F) {
// This checks that any string can be recovered as the same string.
func Fuzz_maybeTime_1(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
@@ -21,7 +22,7 @@ func Fuzz_maybeDate(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeDate(str)
value := maybeTime(str)
switch v := value.(type) {
case time.Time:
@@ -44,3 +45,56 @@ func Fuzz_maybeDate(f *testing.F) {
}
})
}
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_maybeTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
f.Add(0, 999_999_999)
f.Add(0, 1_000_000_000)
f.Add(7956915742, 222_222_222) // twosday
f.Add(639095955742, 222_222_222) // twosday, year 22222AD
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
value := maybeTime(date.Format(time.RFC3339Nano))
switch v := value.(type) {
case time.Time:
// Make sure times round-trip to the same time:
if !v.Equal(date) {
t.Fatalf("did not round-trip: %v", date)
}
// Make with the same zone offset:
_, off1 := v.Zone()
_, off2 := date.Zone()
if off1 != off2 {
t.Fatalf("did not round-trip: %v", date)
}
case string:
t.Fatalf("was not recovered: %v", date)
default:
t.Fatalf("invalid type %T: %v", v, date)
}
}
f.Fuzz(func(t *testing.T, sec, nsec int) {
// Reduce the search space.
if 1e12 < sec || sec < -1e12 {
// Dates before 29000BC and after 33000AD; I think we're safe.
return
}
if 0 < nsec || nsec > 1e10 {
// Out of range nsec: [time.Time.Unix] handles these.
return
}
unix := time.Unix(int64(sec), int64(nsec))
checkTime(t, unix)
checkTime(t, unix.UTC())
checkTime(t, unix.In(time.FixedZone("", -8*3600)))
checkTime(t, unix.In(time.FixedZone("", +8*3600)))
})
}

75
driver_test.go Normal file
View File

@@ -0,0 +1,75 @@
package sqlite3_test
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
func ExampleDriverConn() {
var err error
db, err = sql.Open("sqlite3", "demo.db")
if err != nil {
log.Fatal(err)
}
defer os.Remove("demo.db")
defer db.Close()
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
log.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
if err != nil {
log.Fatal(err)
}
id, err := res.LastInsertId()
if err != nil {
log.Fatal(err)
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {
return err
}
defer blob.Close()
_, err = fmt.Fprint(blob, "Hello BLOB!")
return err
})
if err != nil {
log.Fatal(err)
}
var msg string
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
if err != nil {
log.Fatal(err)
}
fmt.Println(msg)
// Output:
// Hello BLOB!
}

15
embed/README.md Normal file
View File

@@ -0,0 +1,15 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.41.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
- math functions
- FTS3/4/5
- JSON
- R*Tree
- GeoPoly
See the [configuration options](../sqlite3/sqlite_cfg.h).
Built using [`zig`](https://ziglang.org/) version 0.10.1.

View File

@@ -8,43 +8,15 @@ cd -P -- "$(dirname -- "$0")"
# build SQLite
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o sqlite3.wasm ../sqlite3/*.c \
-o sqlite3.wasm ../sqlite3/main.c \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-Wl,--export=malloc \
-Wl,--export=free \
-Wl,--export=malloc_destructor \
-Wl,--export=sqlite3_errcode \
-Wl,--export=sqlite3_errstr \
-Wl,--export=sqlite3_errmsg \
-Wl,--export=sqlite3_error_offset \
-Wl,--export=sqlite3_open_v2 \
-Wl,--export=sqlite3_close \
-Wl,--export=sqlite3_prepare_v3 \
-Wl,--export=sqlite3_finalize \
-Wl,--export=sqlite3_reset \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_parameter_index \
-Wl,--export=sqlite3_bind_parameter_name \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_interrupt \
$(awk '{print "-Wl,--export="$0}' exports.txt)
# optimize SQLite
if which wasm-opt; then
wasm-opt -g -O -o sqlite3.tmp sqlite3.wasm
mv sqlite3.tmp sqlite3.wasm
fi

49
embed/exports.txt Normal file
View File

@@ -0,0 +1,49 @@
free
malloc
malloc_destructor
sqlite3_errcode
sqlite3_errstr
sqlite3_errmsg
sqlite3_error_offset
sqlite3_open_v2
sqlite3_close
sqlite3_close_v2
sqlite3_prepare_v3
sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
sqlite3_bind_parameter_name
sqlite3_bind_null
sqlite3_bind_int64
sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
sqlite3_column_int64
sqlite3_column_double
sqlite3_column_text
sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_time_collation
sqlite3_interrupt_offset

View File

@@ -1,7 +1,9 @@
// Package embed embeds SQLite into your application.
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
// Importing package embed initializes the [sqlite3.Binary] variable
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
package embed
import (

BIN
embed/sqlite3.wasm Executable file → Normal file

Binary file not shown.

142
error.go
View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"fmt"
"runtime"
"strconv"
"strings"
@@ -50,29 +51,160 @@ func (e *Error) Error() string {
return b.String()
}
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
//
// It makes it possible to do:
//
// if errors.Is(err, sqlite3.BUSY) {
// // ... handle BUSY
// }
func (e *Error) Is(err error) bool {
switch c := err.(type) {
case ErrorCode:
return c == e.Code()
case ExtendedErrorCode:
return c == e.ExtendedCode()
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e *Error) Timeout() bool {
return e.ExtendedCode() == BUSY_TIMEOUT
}
// SQL returns the SQL starting at the token that triggered a syntax error.
func (e *Error) SQL() string {
return e.sql
}
// Error implements the error interface.
func (e ErrorCode) Error() string {
switch e {
case _OK:
return "sqlite3: not an error"
case _ROW:
return "sqlite3: another row available"
case _DONE:
return "sqlite3: no more rows available"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
}
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY
}
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
switch x := ErrorCode(e); {
case e == ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case x < _ROW:
return x.Error()
case e == _ROW:
return "sqlite3: another row available"
case e == _DONE:
return "sqlite3: no more rows available"
}
return "sqlite3: unknown error"
}
// Is tests whether this error matches a given [ErrorCode].
func (e ExtendedErrorCode) Is(err error) bool {
c, ok := err.(ErrorCode)
return ok && c == ErrorCode(e)
}
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
type errorString string
func (e errorString) Error() string { return string(e) }
const (
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
notImplErr = errorString("sqlite3: not implemented")
whenceErr = errorString("sqlite3: invalid whence")
offsetErr = errorString("sqlite3: invalid offset")
)
func assertErr() errorString {
@@ -82,3 +214,11 @@ func assertErr() errorString {
}
return errorString(msg)
}
func finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(errorString(msg)) }
}

View File

@@ -1,26 +1,152 @@
package sqlite3
import (
"errors"
"strings"
"testing"
)
func TestError(t *testing.T) {
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
}
if s := err.Error(); s != "sqlite3: 32896" {
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:10)") {
t.Errorf("got %q", s)
}
}
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") {
func TestError(t *testing.T) {
t.Parallel()
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
}
if !errors.Is(&err, ErrorCode(0x80)) {
t.Errorf("want true")
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
}
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
t.Errorf("want true")
}
}
func TestError_Temporary(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code uint64
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), true},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
{
err := &Error{code: tt.code}
if got := err.Temporary(); got != tt.want {
t.Errorf("Error.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ExtendedErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
})
}
}
func TestError_Timeout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code uint64
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), false},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
{
err := &Error{code: tt.code}
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ExtendedErrorCode(tt.code)
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
}
})
}
}
func Test_ErrorCode_Error(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
}
}
func Test_ExtendedErrorCode_Error(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
}
}

View File

@@ -21,7 +21,7 @@ func Example() {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
log.Fatal(err)
}
@@ -30,6 +30,7 @@ func Example() {
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
@@ -47,7 +48,6 @@ func Example() {
if err != nil {
log.Fatal(err)
}
// Output:
// 0 go
// 1 zig

6
go.mod
View File

@@ -4,7 +4,9 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-pre.9
github.com/tetratelabs/wazero v1.0.0-rc.2
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
golang.org/x/sys v0.6.0
)
retract v0.4.0 // tagged from the wrong branch

8
go.sum
View File

@@ -1,8 +1,8 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
github.com/tetratelabs/wazero v1.0.0-rc.2 h1:OA3UUynnoqxrjCQ94mpAtdO4/oMxFQVNL2BXDMOc66Q=
github.com/tetratelabs/wazero v1.0.0-rc.2/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

34
mem.go
View File

@@ -11,17 +11,41 @@ type memory struct {
mod api.Module
}
func (m memory) view(ptr, size uint32) []byte {
func (m memory) view(ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(nilErr)
}
buf, ok := m.mod.Memory().Read(ptr, size)
if size > math.MaxUint32 {
panic(rangeErr)
}
buf, ok := m.mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(rangeErr)
}
return buf
}
func (m memory) readUint8(ptr uint32) uint8 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadByte(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint8(ptr uint32, v uint8) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteByte(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readUint32(ptr uint32) uint32 {
if ptr == 0 {
panic(nilErr)
@@ -33,7 +57,7 @@ func (m memory) readUint32(ptr uint32) uint32 {
return v
}
func (m memory) writeUint32(ptr, v uint32) {
func (m memory) writeUint32(ptr uint32, v uint32) {
if ptr == 0 {
panic(nilErr)
}
@@ -100,12 +124,12 @@ func (m memory) readString(ptr, maxlen uint32) string {
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint32(len(b)))
buf := m.view(ptr, uint64(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
buf := m.view(ptr, uint32(len(s)+1))
buf := m.view(ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -19,6 +19,13 @@ func Test_memory_view_range(t *testing.T) {
t.Error("want panic")
}
func Test_memory_view_overflow(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(1, math.MaxInt64)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)

354
module.go Normal file
View File

@@ -0,0 +1,354 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"crypto/rand"
"io"
"math"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
var sqlite3 struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
}
func instantiateModule() (*module, error) {
ctx := context.Background()
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
name := "sqlite3-" + strconv.FormatUint(sqlite3.instances.Add(1), 10)
cfg := wazero.NewModuleConfig().WithName(name).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, sqlite3.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
return
}
}
if bin == nil {
sqlite3.err = binaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mem memory
api sqliteAPI
vfs io.Closer
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mem = memory{mod}
m.ctx, m.vfs = vfsContext(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := mod.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return m.mem.readUint32(uint32(global.Get()))
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
timeCollation: getFun("sqlite3_time_collation"),
interrupt: getVal("sqlite3_interrupt_offset"),
}
if err != nil {
return nil, err
}
return m, nil
}
func (m *module) close() error {
err := m.mem.mod.Close(m.ctx)
m.vfs.Close()
return err
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = m.mem.readString(uint32(r[0]), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r = m.call(m.api.erroff, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(m.ctx, params...)
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return r
}
func (m *module) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
m.mem.writeBytes(ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
m.mem.writeString(ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
size: uint32(size),
}
}
type arena struct {
m *module
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
if a.m == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.m.mem.writeString(ptr, s)
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
timeCollation api.Function
interrupt uint32
}

View File

@@ -6,31 +6,64 @@ import (
"testing"
)
func TestConn_new(t *testing.T) {
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.new(math.MaxUint32)
m.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer func() { _ = recover() }()
m.call(m.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer m.close()
testOOM := func(size uint64) {
defer func() { _ = recover() }()
m.new(size)
t.Error("want panic")
}
testOOM(math.MaxUint32)
testOOM(_MAX_ALLOCATION_SIZE)
}
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
arena := db.newArena(16)
defer arena.reset()
arena := m.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -38,7 +71,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
if got := m.mem.readString(ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -47,33 +80,34 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
if got := m.mem.readString(ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
}
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newBytes(nil)
ptr := m.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = db.newBytes(buf)
ptr = m.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := db.mem.view(ptr, uint32(len(want))); !bytes.Equal(got, want) {
if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -81,25 +115,25 @@ func TestConn_newBytes(t *testing.T) {
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := db.mem.view(ptr, uint32(len(want))); string(got) != want {
if got := m.mem.view(ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -107,40 +141,40 @@ func TestConn_newString(t *testing.T) {
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
if got := m.mem.readString(ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := db.mem.readString(ptr, 0); got != "" {
if got := m.mem.readString(ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
db.mem.readString(ptr, uint32(len(want)/2))
m.mem.readString(ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
db.mem.readString(0, math.MaxUint32)
m.mem.readString(0, math.MaxUint32)
t.Error("want panic")
}()
}
@@ -148,18 +182,18 @@ func TestConn_getString(t *testing.T) {
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
db.free(0)
m.free(0)
ptr := db.new(1)
ptr := m.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
db.free(ptr)
m.free(ptr)
}

3
sqlite3/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
sqlite3.c
sqlite3.h
sqlite3ext.h

View File

@@ -4,7 +4,7 @@ set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "sqlite3.c" ]; then
url="https://www.sqlite.org/2022/sqlite-amalgamation-3400100.zip"
url="https://sqlite.org/2023/sqlite-amalgamation-3410100.zip"
curl "$url" > sqlite.zip
unzip -d . sqlite.zip
mv sqlite-amalgamation-*/sqlite3* .

5
sqlite3/format.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/usr/bin/env bash
cd -P -- "$(dirname -- "$0")"
shopt -s extglob
clang-format -i !(sqlite3*).@(c|h)

View File

@@ -1,102 +1,20 @@
#include <stdbool.h>
#include <stdlib.h>
#include <time.h>
#include <stddef.h>
#include "sqlite3.h"
#include "sqlite3.c"
//
#include "os.c"
#include "qsort.c"
#include "time.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
int main() {
int rc = sqlite3_initialize();
if (rc != SQLITE_OK) return 1;
}
int go_localtime(sqlite3_int64, struct tm *);
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
int go_sleep(sqlite3_vfs *, int microseconds);
int go_current_time(sqlite3_vfs *, double *);
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct go_file {
sqlite3_file base;
int id;
int eLock;
};
int go_close(sqlite3_file *);
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int go_truncate(sqlite3_file *, sqlite3_int64 size);
int go_sync(sqlite3_file *, int flags);
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int go_file_control(sqlite3_file *pFile, int op, void *pArg);
int go_lock(sqlite3_file *pFile, int eLock);
int go_unlock(sqlite3_file *pFile, int eLock);
int go_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime((sqlite3_int64)*pTime, pTm);
}
static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods go_io = {
.iVersion = 1,
.xClose = go_close,
.xRead = go_read,
.xWrite = go_write,
.xTruncate = go_truncate,
.xSync = go_sync,
.xFileSize = go_file_size,
.xLock = go_lock,
.xUnlock = go_unlock,
.xCheckReservedLock = go_check_reserved_lock,
.xFileControl = no_file_control,
.xSectorSize = no_sector_size,
.xDeviceCharacteristics = no_device_characteristics,
};
int rc = go_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &go_io : NULL;
return rc;
}
int sqlite3_os_init() {
static sqlite3_vfs go_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = "go",
.xOpen = go_open_c,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&go_vfs, /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;

144
sqlite3/os.c Normal file
View File

@@ -0,0 +1,144 @@
#include <time.h>
#include "sqlite3.h"
int os_localtime(sqlite3_int64, struct tm *);
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
int os_sleep(sqlite3_vfs *, int microseconds);
int os_current_time(sqlite3_vfs *, double *);
int os_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int os_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int os_delete(sqlite3_vfs *, const char *zName, int syncDir);
int os_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct os_file {
sqlite3_file base;
int id;
char lock;
char psow;
int lockTimeout;
};
static_assert(offsetof(struct os_file, id) == 4, "Unexpected offset");
static_assert(offsetof(struct os_file, lock) == 8, "Unexpected offset");
static_assert(offsetof(struct os_file, psow) == 9, "Unexpected offset");
static_assert(offsetof(struct os_file, lockTimeout) == 12, "Unexpected offset");
int os_close(sqlite3_file *);
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int os_truncate(sqlite3_file *, sqlite3_int64 size);
int os_sync(sqlite3_file *, int flags);
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int os_file_control(sqlite3_file *, int op, void *pArg);
int os_lock(sqlite3_file *, int eLock);
int os_unlock(sqlite3_file *, int eLock);
int os_check_reserved_lock(sqlite3_file *, int *pResOut);
static int os_file_control_w(sqlite3_file *file, int op, void *pArg) {
struct os_file *pFile = (struct os_file *)file;
switch (op) {
case SQLITE_FCNTL_VFSNAME: {
*(char **)pArg = sqlite3_mprintf("%s", "os");
return SQLITE_OK;
}
case SQLITE_FCNTL_LOCKSTATE: {
*(int *)pArg = pFile->lock;
return SQLITE_OK;
}
case SQLITE_FCNTL_LOCK_TIMEOUT: {
int iOld = pFile->lockTimeout;
pFile->lockTimeout = *(int *)pArg;
*(int *)pArg = iOld;
return SQLITE_OK;
}
case SQLITE_FCNTL_POWERSAFE_OVERWRITE: {
if (*(int *)pArg < 0) {
*(int *)pArg = pFile->psow;
} else {
pFile->psow = *(int *)pArg;
}
return SQLITE_OK;
}
case SQLITE_FCNTL_SIZE_HINT:
case SQLITE_FCNTL_HAS_MOVED:
return os_file_control(file, op, pArg);
}
// Consider also implementing these opcodes (in use by SQLite):
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_SYNC
return SQLITE_NOTFOUND;
}
static int os_sector_size(sqlite3_file *file) {
return SQLITE_DEFAULT_SECTOR_SIZE;
}
static int os_device_characteristics(sqlite3_file *file) {
struct os_file *pFile = (struct os_file *)file;
return pFile->psow ? SQLITE_IOCAP_POWERSAFE_OVERWRITE : 0;
}
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = os_close,
.xRead = os_read,
.xWrite = os_write,
.xTruncate = os_truncate,
.xSync = os_sync,
.xFileSize = os_file_size,
.xLock = os_lock,
.xUnlock = os_unlock,
.xCheckReservedLock = os_check_reserved_lock,
.xFileControl = os_file_control_w,
.xSectorSize = os_sector_size,
.xDeviceCharacteristics = os_device_characteristics,
};
memset(file, 0, sizeof(struct os_file));
int rc = os_open(vfs, zName, file, flags, pOutFlags);
if (rc) {
return rc;
}
struct os_file *pFile = (struct os_file *)file;
pFile->base.pMethods = &os_io;
if (flags & SQLITE_OPEN_MAIN_DB) {
pFile->psow =
sqlite3_uri_boolean(zName, "psow", SQLITE_POWERSAFE_OVERWRITE);
}
return SQLITE_OK;
}
sqlite3_vfs *os_vfs() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct os_file),
.mxPathname = 512,
.zName = "os",
.xOpen = os_open_w,
.xDelete = os_delete,
.xAccess = os_access,
.xFullPathname = os_full_pathname,
.xRandomness = os_randomness,
.xSleep = os_sleep,
.xCurrentTime = os_current_time,
.xCurrentTimeInt64 = os_current_time_64,
};
return &os_vfs;
}
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime((sqlite3_int64)*pTime, pTm);
}

14
sqlite3/qsort.c Normal file
View File

@@ -0,0 +1,14 @@
#include <stddef.h>
void qsort_r(void *, size_t, size_t,
int (*)(const void *, const void *, void *), void *);
typedef int (*cmpfun)(const void *, const void *);
static int wrapper_cmp(const void *v1, const void *v2, void *cmp) {
return ((cmpfun)cmp)(v1, v2);
}
void qsort(void *base, size_t nel, size_t width, cmpfun cmp) {
qsort_r(base, nel, width, wrapper_cmp, cmp);
}

View File

@@ -28,18 +28,35 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Because WASM does not support shared memory,
// SQLite disables it for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#ifndef SQLITE_DEFAULT_LOCKING_MODE
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Recommended Extensions
// #define SQLITE_ENABLE_MATH_FUNCTIONS 1
// #define SQLITE_ENABLE_FTS3 1
// #define SQLITE_ENABLE_FTS3_PARENTHESIS 1
// #define SQLITE_ENABLE_FTS4 1
// #define SQLITE_ENABLE_FTS5 1
// #define SQLITE_ENABLE_RTREE 1
// #define SQLITE_ENABLE_GEOPOLY 1
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
#define SQLITE_ENABLE_JSON1 1
#define SQLITE_ENABLE_FTS3 1
#define SQLITE_ENABLE_FTS3_PARENTHESIS 1
#define SQLITE_ENABLE_FTS4 1
#define SQLITE_ENABLE_FTS5 1
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
// Need this to access WAL databases without the use of shared memory.
#define SQLITE_DEFAULT_LOCKING_MODE 1
// Snapshot
// #define SQLITE_ENABLE_SNAPSHOT 1
// Session Extension
// #define SQLITE_ENABLE_SESSION 1
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
// Resumable Bulk Update Extension
// #define SQLITE_ENABLE_RBU 1
// Implemented in Go.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

31
sqlite3/time.c Normal file
View File

@@ -0,0 +1,31 @@
#include <string.h>
#include "sqlite3.h"
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
const void *pKey2) {
// Remove a Z suffix if one key is no longer than the other.
// A Z suffix collates before any character but after the empty string.
// This avoids making different keys equal.
const int nK1 = nKey1;
const int nK2 = nKey2;
const char *pK1 = (const char *)pKey1;
const char *pK2 = (const char *)pKey2;
if (nK1 && nK1 <= nK2 && pK1[nK1 - 1] == 'Z') {
nKey1--;
}
if (nK2 && nK2 <= nK1 && pK2[nK2 - 1] == 'Z') {
nKey2--;
}
int n = nKey1 < nKey2 ? nKey1 : nKey2;
int rc = memcmp(pKey1, pKey2, n);
if (rc == 0) {
rc = nKey1 - nKey2;
}
return rc;
}
int sqlite3_time_collation(sqlite3 *db) {
return sqlite3_create_collation(db, "TIME", SQLITE_UTF8, 0, time_collation);
}

142
stmt.go
View File

@@ -16,7 +16,7 @@ type Stmt struct {
// Close destroys the prepared statement object.
//
// It is safe to close a nil, zero or closed prepared statement.
// It is safe to close a nil, zero or closed Stmt.
//
// https://www.sqlite.org/c3ref/finalize.html
func (s *Stmt) Close() error {
@@ -24,10 +24,7 @@ func (s *Stmt) Close() error {
return nil
}
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.finalize, uint64(s.handle))
s.handle = 0
return s.c.error(r[0])
@@ -37,10 +34,7 @@ func (s *Stmt) Close() error {
//
// https://www.sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.reset, uint64(s.handle))
s.err = nil
return s.c.error(r[0])
}
@@ -49,10 +43,7 @@ func (s *Stmt) Reset() error {
//
// https://www.sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r[0])
}
@@ -66,10 +57,8 @@ func (s *Stmt) ClearBindings() error {
//
// https://www.sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r[0] == _ROW {
return true
}
@@ -101,11 +90,8 @@ func (s *Stmt) Exec() error {
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r, err := s.c.api.bindCount.Call(s.c.ctx,
r := s.c.call(s.c.api.bindCount,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -116,11 +102,8 @@ func (s *Stmt) BindCount() int {
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.reset()
namePtr := s.c.arena.string(name)
r, err := s.c.api.bindIndex.Call(s.c.ctx,
r := s.c.call(s.c.api.bindIndex,
uint64(s.handle), uint64(namePtr))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -129,11 +112,8 @@ func (s *Stmt) BindIndex(name string) int {
//
// https://www.sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r, err := s.c.api.bindName.Call(s.c.ctx,
r := s.c.call(s.c.api.bindName,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
@@ -168,11 +148,8 @@ func (s *Stmt) BindInt(param int, value int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
r, err := s.c.api.bindInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.bindInteger,
uint64(s.handle), uint64(param), uint64(value))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -181,11 +158,8 @@ func (s *Stmt) BindInt64(param int, value int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
r, err := s.c.api.bindFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.bindFloat,
uint64(s.handle), uint64(param), math.Float64bits(value))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -195,13 +169,10 @@ func (s *Stmt) BindFloat(param int, value float64) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindText(param int, value string) error {
ptr := s.c.newString(value)
r, err := s.c.api.bindText.Call(s.c.ctx,
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -212,13 +183,10 @@ func (s *Stmt) BindText(param int, value string) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
ptr := s.c.newBytes(value)
r, err := s.c.api.bindBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -227,11 +195,8 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r, err := s.c.api.bindZeroBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindZeroBlob,
uint64(s.handle), uint64(param), uint64(n))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -240,11 +205,8 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
r, err := s.c.api.bindNull.Call(s.c.ctx,
r := s.c.call(s.c.api.bindNull,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -270,11 +232,8 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r, err := s.c.api.columnCount.Call(s.c.ctx,
r := s.c.call(s.c.api.columnCount,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -283,11 +242,8 @@ func (s *Stmt) ColumnCount() int {
//
// https://www.sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r, err := s.c.api.columnName.Call(s.c.ctx,
r := s.c.call(s.c.api.columnName,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
@@ -301,11 +257,8 @@ func (s *Stmt) ColumnName(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
r, err := s.c.api.columnType.Call(s.c.ctx,
r := s.c.call(s.c.api.columnType,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return Datatype(r[0])
}
@@ -336,11 +289,8 @@ func (s *Stmt) ColumnInt(col int) int {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
r, err := s.c.api.columnInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.columnInteger,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return int64(r[0])
}
@@ -349,11 +299,8 @@ func (s *Stmt) ColumnInt64(col int) int64 {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
r, err := s.c.api.columnFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.columnFloat,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return math.Float64frombits(r[0])
}
@@ -387,29 +334,20 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r, err := s.c.api.columnText.Call(s.c.ctx,
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return ""
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
mem := s.c.mem.view(ptr, r[0])
return string(mem)
}
@@ -419,28 +357,34 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
r, err := s.c.api.columnBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return buf[0:0]
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
mem := s.c.mem.view(ptr, r[0])
return append(buf[0:0], mem...)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

127
tests/backup_test.go Normal file
View File

@@ -0,0 +1,127 @@
package tests
import (
"path/filepath"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBackup(t *testing.T) {
t.Parallel()
backupName := filepath.Join(t.TempDir(), "backup.db")
func() { // Create backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
err = db.Backup("main", backupName)
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Restore backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Restore("main", backupName)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Errors.
db, err := sqlite3.Open(backupName)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
}

297
tests/blob_test.go Normal file
View File

@@ -0,0 +1,297 @@
package tests
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBlob(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
size := blob.Size()
if size != 1024 {
t.Errorf("got %d, want 1024", size)
}
var data [1280]byte
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:size/2]))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:]))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
_, err = io.Copy(blob, bytes.NewReader(data[size/2:size]))
if err != nil {
t.Fatal(err)
}
_, err = blob.Seek(size/4, io.SeekStart)
if err != nil {
t.Fatal(err)
}
if got, err := io.ReadAll(blob); err != nil {
t.Fatal(err)
} else if !bytes.Equal(got, data[size/4:size]) {
t.Errorf("got %q, want %q", got, data[size/4:size])
}
if n, err := blob.Read(make([]byte, 1)); n != 0 || err != io.EOF {
t.Errorf("got (%d, %v), want (0, EOF)", n, err)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
_, err = db.OpenBlob("", "test", "col", db.LastInsertRowID(), false)
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
}
func TestBlob_Write_readonly(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
_, err = blob.Write([]byte("data"))
if !errors.Is(err, sqlite3.READONLY) {
t.Fatal("want error")
}
}
func TestBlob_Read_expired(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
err = db.Exec(`DELETE FROM test`)
if err != nil {
t.Fatal(err)
}
_, err = io.ReadAll(blob)
if !errors.Is(err, sqlite3.ABORT) {
t.Fatal("want error", err)
}
}
func TestBlob_Seek(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
_, err = blob.Seek(0, 10)
if err == nil {
t.Fatal("want error")
}
_, err = blob.Seek(-1, io.SeekCurrent)
if err == nil {
t.Fatal("want error")
}
n, err := blob.Seek(1, io.SeekEnd)
if err != nil {
t.Fatal(err)
}
if n != blob.Size()+1 {
t.Errorf("got %d, want %d", n, blob.Size())
}
_, err = blob.Write([]byte("data"))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
}
func TestBlob_Reopen(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
var rowids []int64
for i := 0; i < 100; i++ {
err = db.Exec(`INSERT INTO test VALUES (zeroblob(10))`)
if err != nil {
t.Fatal(err)
}
rowids = append(rowids, db.LastInsertRowID())
}
var blob *sqlite3.Blob
for i, rowid := range rowids {
if i > 0 {
err = blob.Reopen(rowid)
} else {
blob, err = db.OpenBlob("main", "test", "col", rowid, true)
}
if err != nil {
t.Fatal(err)
}
_, err = fmt.Fprintf(blob, "blob %d\n", i)
if err != nil {
t.Fatal(err)
}
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
for i, rowid := range rowids {
if i > 0 {
err = blob.Reopen(rowid)
} else {
blob, err = db.OpenBlob("main", "test", "col", rowid, false)
}
if err != nil {
t.Fatal(err)
}
var got int
_, err = fmt.Fscanf(blob, "blob %d\n", &got)
if err != nil {
t.Fatal(err)
}
if got != i {
t.Errorf("got %d, want %d", got, i)
}
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
}

View File

@@ -41,7 +41,9 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
}
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "foo.db")+
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}
@@ -104,7 +106,7 @@ func testTxQuery(t params) {
}
defer tx.Rollback()
_, err = t.DB.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
_, err = tx.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
if err != nil {
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
}

View File

@@ -13,19 +13,12 @@ import (
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
_, err := sqlite3.OpenFlags(".", 0)
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
@@ -53,23 +46,21 @@ func TestConn_Close_BUSY(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
if !errors.Is(err, sqlite3.BUSY) {
t.Errorf("got %v, want sqlite3.BUSY", err)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -77,7 +68,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
db.SetInterrupt(ctx)
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
@@ -85,7 +76,7 @@ func TestConn_SetInterrupt(t *testing.T) {
t.Fatal(err)
}
db.SetInterrupt(nil)
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`
WITH RECURSIVE
@@ -103,40 +94,24 @@ func TestConn_SetInterrupt(t *testing.T) {
}
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
db.SetInterrupt(ctx.Done())
var serr *sqlite3.Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(nil)
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
db.SetInterrupt(ctx)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
@@ -207,7 +182,7 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
@@ -221,9 +196,9 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
t.Error("got SQL:", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}

View File

@@ -30,7 +30,7 @@ func testDB(t *testing.T, name string) {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}

View File

@@ -34,7 +34,7 @@ func TestDriver(t *testing.T) {
}
res, err := conn.ExecContext(ctx,
`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}

186
tests/mptest/mptest_test.go Normal file
View File

@@ -0,0 +1,186 @@
package mptest
import (
"bytes"
"context"
"crypto/rand"
"embed"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
_ "unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
_ "github.com/ncruces/go-sqlite3"
)
//go:embed testdata/mptest.wasm
var binary []byte
//go:embed testdata/*.*test
var scripts embed.FS
//go:linkname vfsNewEnvModuleBuilder github.com/ncruces/go-sqlite3.vfsNewEnvModuleBuilder
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder
//go:linkname vfsContext github.com/ncruces/go-sqlite3.vfsContext
func vfsContext(ctx context.Context) (context.Context, io.Closer)
var (
rt wazero.Runtime
module wazero.CompiledModule
instances atomic.Uint64
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfsNewEnvModuleBuilder(rt)
env.NewFunctionBuilder().WithFunc(system).Export("system")
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func config(ctx context.Context) wazero.ModuleConfig {
name := strconv.FormatUint(instances.Add(1), 10)
log := ctx.Value(logger{}).(io.Writer)
fs, err := fs.Sub(scripts, "testdata")
if err != nil {
panic(err)
}
return wazero.NewModuleConfig().
WithName(name).WithStdout(log).WithStderr(log).WithFS(fs).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
}
func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr)
buf = buf[:bytes.IndexByte(buf, 0)]
args := strings.Split(string(buf), " ")
for i := range args {
args[i] = strings.Trim(args[i], `"`)
}
args = args[:len(args)-1]
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := vfsContext(ctx)
rt.InstantiateModule(ctx, module, cfg)
vfs.Close()
}()
return 0
}
func Test_config01(t *testing.T) {
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_config02(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_crash01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_multiwrite01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func newContext(t *testing.T) context.Context {
return context.WithValue(context.Background(), logger{}, &testWriter{T: t})
}
type logger struct{}
type testWriter struct {
*testing.T
buf []byte
mtx sync.Mutex
}
func (l *testWriter) Write(p []byte) (n int, err error) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.buf = append(l.buf, p...)
for {
before, after, found := bytes.Cut(l.buf, []byte("\n"))
if !found {
return len(p), nil
}
l.Logf("%s", before)
l.buf = after
}
}

2
tests/mptest/testdata/.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
mptest.wasm filter=lfs diff=lfs merge=lfs -text
*.*test -crlf

1
tests/mptest/testdata/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
mptest.c

31
tests/mptest/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,31 @@
#!/usr/bin/env bash
set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "mptest.c" ]; then
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/mptest/mptest.c"
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/mptest/config01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/mptest/config02.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/mptest/crash01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/mptest/crash02.subtest"
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/mptest/multiwrite01.test"
fi
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o mptest.wasm main.c \
-I../../../sqlite3 \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid
if which wasm-opt; then
wasm-opt -g -O -o mptest.tmp mptest.wasm
mv mptest.tmp mptest.wasm
fi

46
tests/mptest/testdata/config01.test vendored Normal file
View File

@@ -0,0 +1,46 @@
/*
** Configure five tasks in different ways, then run tests.
*/
--if vfsname() GLOB 'unix'
PRAGMA page_size=8192;
--task 1
PRAGMA journal_mode=PERSIST;
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA journal_mode=TRUNCATE;
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA journal_mode=MEMORY;
--end
--task 4
PRAGMA journal_mode=OFF;
--end
--task 4
PRAGMA mmap_size(268435456);
--end
--source multiwrite01.test
--wait all
PRAGMA page_size=16384;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384

123
tests/mptest/testdata/config02.test vendored Normal file
View File

@@ -0,0 +1,123 @@
/*
** Configure five tasks in different ways, then run tests.
*/
PRAGMA page_size=512;
--task 1
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA mmap_size=8192;
--end
--task 4
PRAGMA mmap_size=65536;
--end
--task 5
PRAGMA mmap_size=268435456;
--end
--source multiwrite01.test
--source crash02.subtest
PRAGMA page_size=1024;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 1024 1024 1024 1024 1024
PRAGMA page_size=2048;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 2048 2048 2048 2048 2048
PRAGMA page_size=8192;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 8192 8192 8192 8192 8192
PRAGMA page_size=16384;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384
PRAGMA auto_vacuum=FULL;
VACUUM;
--source multiwrite01.test
--source crash02.subtest
--wait all
PRAGMA auto_vacuum=FULL;
PRAGMA page_size=512;
VACUUM;
--source multiwrite01.test
--source crash02.subtest

106
tests/mptest/testdata/crash01.test vendored Normal file
View File

@@ -0,0 +1,106 @@
/* Test cases involving incomplete transactions that must be rolled back.
*/
--task 1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait 1
--task 2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
INSERT INTO t2 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
INSERT INTO t3 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
INSERT INTO t4 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
INSERT INTO t5 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
/* After the database file has been set up, run the crash2 subscript
** multiple times. */
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest

53
tests/mptest/testdata/crash02.subtest vendored Normal file
View File

@@ -0,0 +1,53 @@
/*
** This script is called from crash01.test and config02.test and perhaps other
** script. After the database file has been set up, make a big rollback
** journal in client 1, then crash client 1.
** Then in the other clients, do an integrity check.
*/
--task 1 leave-hot-journal
--sleep 5
--finish
PRAGMA cache_size=10;
BEGIN;
UPDATE t1 SET b=randomblob(20000);
UPDATE t2 SET b=randomblob(20000);
UPDATE t3 SET b=randomblob(20000);
UPDATE t4 SET b=randomblob(20000);
UPDATE t5 SET b=randomblob(20000);
UPDATE t1 SET b=NULL;
UPDATE t2 SET b=NULL;
UPDATE t3 SET b=NULL;
UPDATE t4 SET b=NULL;
UPDATE t5 SET b=NULL;
--print Task one crashing an incomplete transaction
--exit 1
--end
--task 2 integrity_check-2
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 3 integrity_check-3
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 4 integrity_check-4
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 5 integrity_check-5
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--wait all

24
tests/mptest/testdata/main.c vendored Normal file
View File

@@ -0,0 +1,24 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.c"
//
#include "os.c"
#include "qsort.c"
#include "time.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
__attribute__((constructor)) void premain() { sqlite3_initialize(); }
static int dont_unlink(const char *pathname) { return 0; }
#define sqlite3_enable_load_extension(...)
#define sqlite3_trace(...)
#define unlink dont_unlink
#undef UNUSED_PARAMETER
#include "mptest.c"

3
tests/mptest/testdata/mptest.wasm vendored Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af307c3555fcf5f78e07a079d8f099400c3e013508f30f5efbe3b7f259be2092
size 969309

415
tests/mptest/testdata/multiwrite01.test vendored Normal file
View File

@@ -0,0 +1,415 @@
/*
** This script sets up five different tasks all writing and updating
** the database at the same time, but each in its own table.
*/
--task 1 build-t1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 2 build-t2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t2 VALUES(1, randomblob(2000));
INSERT INTO t2 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t2 SELECT a+2, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+4, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+8, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+16, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+32, randomblob(1500) FROM t2;
SELECT count(*) FROM t2;
--match 64
SELECT avg(length(b)) FROM t2;
--match 1500.0
--sleep 2
UPDATE t2 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3 build-t3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t3 VALUES(1, randomblob(2000));
INSERT INTO t3 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t3 SELECT a+2, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+4, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+8, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+16, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+32, randomblob(1500) FROM t3;
SELECT count(*) FROM t3;
--match 64
SELECT avg(length(b)) FROM t3;
--match 1500.0
--sleep 2
UPDATE t3 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4 build-t4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t4 VALUES(1, randomblob(2000));
INSERT INTO t4 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t4 SELECT a+2, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+4, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+8, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+16, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+32, randomblob(1500) FROM t4;
SELECT count(*) FROM t4;
--match 64
SELECT avg(length(b)) FROM t4;
--match 1500.0
--sleep 2
UPDATE t4 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5 build-t5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t5 VALUES(1, randomblob(2000));
INSERT INTO t5 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t5 SELECT a+2, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+4, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+8, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+16, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+32, randomblob(1500) FROM t5;
SELECT count(*) FROM t5;
--match 64
SELECT avg(length(b)) FROM t5;
--match 1500.0
--sleep 2
UPDATE t5 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
SELECT count(*), sum(length(b)) FROM t1;
--match 64 247
SELECT count(*), sum(length(b)) FROM t2;
--match 64 247
SELECT count(*), sum(length(b)) FROM t3;
--match 64 247
SELECT count(*), sum(length(b)) FROM t4;
--match 64 247
SELECT count(*), sum(length(b)) FROM t5;
--match 64 247
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
--task 5
DROP INDEX t5b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t5b ON t5(b DESC);
--end
--task 3
DROP INDEX t3b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t3b ON t3(b DESC);
--end
--task 1
DROP INDEX t1b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t1b ON t1(b DESC);
--end
--task 2
DROP INDEX t2b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t2b ON t2(b DESC);
--end
--task 4
DROP INDEX t4b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t4b ON t4(b DESC);
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
VACUUM;
PRAGMA integrity_check(10);
--match ok
--task 1
UPDATE t1 SET b=randomblob(20000);
--sleep 5
UPDATE t1 SET b='x'||a||'y';
SELECT a FROM t1 WHERE b='x63y';
--match 63
--end
--task 2
UPDATE t2 SET b=randomblob(20000);
--sleep 5
UPDATE t2 SET b='x'||a||'y';
SELECT a FROM t2 WHERE b='x63y';
--match 63
--end
--task 3
UPDATE t3 SET b=randomblob(20000);
--sleep 5
UPDATE t3 SET b='x'||a||'y';
SELECT a FROM t3 WHERE b='x63y';
--match 63
--end
--task 4
UPDATE t4 SET b=randomblob(20000);
--sleep 5
UPDATE t4 SET b='x'||a||'y';
SELECT a FROM t4 WHERE b='x63y';
--match 63
--end
--task 5
UPDATE t5 SET b=randomblob(20000);
--sleep 5
UPDATE t5 SET b='x'||a||'y';
SELECT a FROM t5 WHERE b='x63y';
--match 63
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--wait all

View File

@@ -1,7 +1,6 @@
package tests
import (
"errors"
"io"
"os"
"os/exec"
@@ -15,14 +14,21 @@ import (
)
func TestParallel(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
testParallel(t, name, 100)
testParallel(t, name, iter)
testIntegrity(t, name)
}
func TestMultiProcess(t *testing.T) {
if testing.Short() {
t.Skip()
t.Skip("skipping in short mode")
}
name := filepath.Join(t.TempDir(), "test.db")
@@ -46,10 +52,6 @@ func TestMultiProcess(t *testing.T) {
testParallel(t, name, 1000)
if err := cmd.Wait(); err != nil {
t.Error(err)
var eerr *exec.ExitError
if errors.As(err, &eerr) {
t.Error(eerr.Stderr)
}
}
testIntegrity(t, name)
}
@@ -72,8 +74,10 @@ func testParallel(t *testing.T, name string, n int) {
defer db.Close()
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 10000;
PRAGMA busy_timeout=10000;
PRAGMA synchronous=off;
PRAGMA locking_mode=normal;
PRAGMA journal_mode=truncate;
`)
if err != nil {
return err
@@ -84,7 +88,7 @@ func testParallel(t *testing.T, name string, n int) {
return err
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
return err
}
@@ -100,8 +104,8 @@ func testParallel(t *testing.T, name string, n int) {
defer db.Close()
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 10000;
PRAGMA busy_timeout=10000;
PRAGMA locking_mode=normal;
`)
if err != nil {
return err
@@ -138,7 +142,7 @@ func testParallel(t *testing.T, name string, n int) {
}
var group errgroup.Group
group.SetLimit(4)
group.SetLimit(6)
for i := 0; i < n; i++ {
if i&7 != 7 {
group.Go(reader)
@@ -148,7 +152,7 @@ func testParallel(t *testing.T, name string, n int) {
}
err = group.Wait()
if err != nil {
t.Fatal(err)
t.Error(err)
}
}

View File

@@ -0,0 +1,92 @@
package speedtest1
import (
"bytes"
"context"
"crypto/rand"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
_ "embed"
_ "unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
_ "github.com/ncruces/go-sqlite3"
)
//go:embed testdata/speedtest1.wasm
var binary []byte
//go:linkname vfsNewEnvModuleBuilder github.com/ncruces/go-sqlite3.vfsNewEnvModuleBuilder
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder
//go:linkname vfsContext github.com/ncruces/go-sqlite3.vfsContext
func vfsContext(ctx context.Context) (context.Context, io.Closer)
var (
rt wazero.Runtime
module wazero.CompiledModule
output bytes.Buffer
options []string
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfsNewEnvModuleBuilder(rt)
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func TestMain(m *testing.M) {
i := 1
options = append(options, "speedtest1")
for _, arg := range os.Args[1:] {
if strings.HasPrefix(arg, "-test.") {
os.Args[i] = arg
i++
} else {
options = append(options, arg)
}
}
os.Args = os.Args[:i]
code := m.Run()
io.Copy(os.Stderr, &output)
os.Exit(code)
}
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := vfsContext(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().
WithArgs(args...).WithName("speedtest1").
WithStdout(&output).WithStderr(&output).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
b.Error(err)
}
vfs.Close()
mod.Close(ctx)
}

View File

@@ -0,0 +1 @@
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text

1
tests/speedtest1/testdata/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
speedtest1.c

21
tests/speedtest1/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "mptest.c" ]; then
curl -sOL "https://github.com/sqlite/sqlite/raw/version-3.41.1/test/speedtest1.c"
fi
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o speedtest1.wasm main.c \
-I../../../sqlite3 \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H
if which wasm-opt; then
wasm-opt -g -O -o speedtest1.tmp speedtest1.wasm
mv speedtest1.tmp speedtest1.wasm
fi

20
tests/speedtest1/testdata/main.c vendored Normal file
View File

@@ -0,0 +1,20 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.c"
//
#include "os.c"
#include "qsort.c"
#include "time.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
__attribute__((constructor)) void premain() { sqlite3_initialize(); }
#define randomFunc(args...) randomFunc2(args)
#include "speedtest1.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:996c2445ad12c91dce1a59ab5929cfaa3a09a4ac82859fecc7de5e5f9c955f80
size 1001607

View File

@@ -22,7 +22,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO test(col) VALUES(?)`)
stmt, _, err := db.Prepare(`INSERT INTO test VALUES (?)`)
if err != nil {
t.Fatal(err)
}
@@ -400,7 +400,7 @@ func TestStmt_BindName(t *testing.T) {
}
}
func TestStmt_Time(t *testing.T) {
func TestStmt_ColumnTime(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
@@ -430,23 +430,23 @@ func TestStmt_Time(t *testing.T) {
}
if now := time.Now(); stmt.Step() {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !reference.Equal(got) {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !reference.Equal(got) {
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); reference.Sub(got) > time.Millisecond {
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); got.Sub(reference).Abs() > time.Millisecond {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); now.Sub(got) > time.Millisecond {
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second/10 {
t.Errorf("got %v, want %v", got, now)
}

169
tests/time_test.go Normal file
View File

@@ -0,0 +1,169 @@
package tests
import (
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
time time.Time
want any
}{
{sqlite3.TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{sqlite3.TimeFormatJulianDay, reference, 2456572.849526851851852},
{sqlite3.TimeFormatUnix, reference, int64(1381134199)},
{sqlite3.TimeFormatUnixFrac, reference, 1381134199.120},
{sqlite3.TimeFormatUnixMilli, reference, int64(1381134199_120)},
{sqlite3.TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{sqlite3.TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{sqlite3.TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
stmt.Step()
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
if err != nil {
t.Fatal(err)
}
var t0 time.Time
for stmt.Step() {
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
if t0.After(t1) {
t.Errorf("got %v after %v", t0, t1)
}
t0 = t1
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}

535
tests/tx_test.go Normal file
View File

@@ -0,0 +1,535 @@
package tests
import (
"context"
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestConn_Transaction_exec(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
errFailed := errors.New("failed")
count := func() int {
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
t.Fatal(stmt.Err())
return 0
}
insert := func(succeed bool) (err error) {
tx := db.Begin()
defer tx.End(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errFailed
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
err = insert(false)
if err != errFailed {
t.Errorf("got %v, want errFailed", err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
}
func TestConn_Transaction_panic(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES ('one');`)
if err != nil {
t.Fatal(err)
}
panics := func() (err error) {
tx := db.Begin()
defer tx.End(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
return err
}
panic("omg!")
}
defer func() {
p := recover()
if p != "omg!" {
t.Errorf("got %v, want panic", p)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
return
}
t.Fatal(stmt.Err())
}()
err = panics()
if err != nil {
t.Error(err)
}
}
func TestConn_Transaction_interrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
tx, err := db.BeginImmediate()
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
tx.End(&err)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
tx, err = db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
cancel()
_, err = db.BeginImmediate()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = db.Exec(`INSERT INTO test VALUES (3)`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Transaction_interrupted(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
cancel()
tx := db.Begin()
err = tx.Commit()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
tx := db.Begin()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`COMMIT`)
if err != nil {
t.Fatal(err)
}
tx.End(&err)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_exec(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
errFailed := errors.New("failed")
count := func() int {
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
t.Fatal(stmt.Err())
return 0
}
insert := func(succeed bool) (err error) {
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errFailed
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
err = insert(false)
if err != errFailed {
t.Errorf("got %v, want errFailed", err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
}
func TestConn_Savepoint_panic(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES ('one');`)
if err != nil {
t.Fatal(err)
}
panics := func() (err error) {
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
return err
}
panic("omg!")
}
defer func() {
p := recover()
if p != "omg!" {
t.Errorf("got %v, want panic", p)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
return
}
t.Fatal(stmt.Err())
}()
err = panics()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_interrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
savept1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
savept2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint().Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = db.Exec(`INSERT INTO test VALUES (4)`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = context.Canceled
savept2.Release(&err)
if err != context.Canceled {
t.Fatal(err)
}
err = nil
savept1.Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_rollback(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`COMMIT`)
if err != nil {
t.Fatal(err)
}
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}

View File

@@ -1,19 +1,23 @@
package sqlite3
package tests
import "testing"
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
data sqlite3.Datatype
want string
}{
{INTEGER, "INTEGER"},
{FLOAT, "FLOAT"},
{TEXT, "TEXT"},
{BLOB, "BLOB"},
{NULL, "NULL"},
{sqlite3.INTEGER, "INTEGER"},
{sqlite3.FLOAT, "FLOAT"},
{sqlite3.TEXT, "TEXT"},
{sqlite3.BLOB, "BLOB"},
{sqlite3.NULL, "NULL"},
{10, "10"},
}
for _, tt := range tests {

86
time.go
View File

@@ -11,6 +11,9 @@ import (
// TimeFormat specifies how to encode/decode time values.
//
// See the documentation for the [TimeFormatDefault] constant
// for formats recognized by SQLite.
//
// https://www.sqlite.org/lang_datefunc.html
type TimeFormat string
@@ -57,12 +60,31 @@ const (
// Encode encodes a time value using this format.
//
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving timezone.
// with nanosecond accuracy, and preserving any timezone offset.
//
// Formats that don't record the timezone
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
//
// Returns a string for the text formats,
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
@@ -81,11 +103,13 @@ func (f TimeFormat) Encode(t time.Time) any {
// Special formats
case TimeFormatDefault, TimeFormatAuto:
f = time.RFC3339Nano
}
// SQLite assumes UTC if unspecified.
if !strings.Contains(string(f), "MST") &&
!strings.Contains(string(f), "Z07") &&
!strings.Contains(string(f), "-07") {
case
TimeFormat1, TimeFormat2,
TimeFormat3, TimeFormat4,
TimeFormat5, TimeFormat6,
TimeFormat7, TimeFormat8,
TimeFormat9, TimeFormat10:
t = t.UTC()
}
return t.Format(string(f))
@@ -93,8 +117,23 @@ func (f TimeFormat) Encode(t time.Time) any {
// Decode decodes a time value using this format.
//
// Decoding of SQLite recognized formats is lenient:
// timezones and fractional seconds are always optional.
// The time value can be a string, an int64, or a float64.
//
// Formats [TimeFormat8] through [TimeFormat10]
// (and [TimeFormat8TZ] through [TimeFormat10TZ])
// assume a date of 2000-01-01.
//
// The timezone indicator and fractional seconds are always optional
// for formats [TimeFormat2] through [TimeFormat10]
// (and [TimeFormat2TZ] through [TimeFormat10TZ]).
//
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) {
@@ -263,14 +302,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
if !ok {
return time.Time{}, timeErr
}
f := string(f)
f = strings.TrimSuffix(f, "Z07:00")
f = strings.TrimSuffix(f, ".000")
t, err := time.Parse(f+"Z07:00", s)
if err != nil {
t, err = time.Parse(f, s)
}
return t, err
return f.parseRelaxed(s)
case
TimeFormat8, TimeFormat8TZ,
@@ -280,13 +312,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
if !ok {
return time.Time{}, timeErr
}
f := string(f)
f = strings.TrimSuffix(f, "Z07:00")
f = strings.TrimSuffix(f, ".000")
t, err := time.Parse(f+"Z07:00", s)
if err != nil {
t, err = time.Parse(f, s)
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
default:
@@ -294,10 +320,20 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
if !ok {
return time.Time{}, timeErr
}
f := string(f)
if f == "" {
f = time.RFC3339Nano
}
return time.Parse(f, s)
return time.Parse(string(f), s)
}
}
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
fs := string(f)
fs = strings.TrimSuffix(fs, "Z07:00")
fs = strings.TrimSuffix(fs, ".000")
t, err := time.Parse(fs+"Z07:00", s)
if err != nil {
return time.Parse(fs, s)
}
return t, nil
}

View File

@@ -1,118 +0,0 @@
package sqlite3
import (
"reflect"
"testing"
"time"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
time time.Time
want any
}{
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{TimeFormatJulianDay, reference, 2456572.849526851851852},
{TimeFormatUnix, reference, int64(1381134199)},
{TimeFormatUnixFrac, reference, 1381134199.120},
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{TimeFormatJulianDay, false, time.Time{}, 0, true},
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{TimeFormatUnix, "abc", time.Time{}, 0, true},
{TimeFormatUnix, false, time.Time{}, 0, true},
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{TimeFormatUnixNano, false, time.Time{}, 0, true},
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199", reference, time.Second, false},
{TimeFormatAuto, "1381134199120", reference, 0, false},
{TimeFormatAuto, "1381134199120000", reference, 0, false},
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormatAuto, "abc", time.Time{}, 0, true},
{TimeFormatAuto, false, time.Time{}, 0, true},
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormat9, "08:23:19.12", reftime, 0, false},
{TimeFormat3, false, time.Time{}, 0, true},
{TimeFormat9, false, time.Time{}, 0, true},
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}

202
tx.go Normal file
View File

@@ -0,0 +1,202 @@
package sqlite3
import (
"context"
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
)
// Tx is an in-progress database transaction.
//
// https://www.sqlite.org/lang_transaction.html
type Tx struct {
c *Conn
}
// Begin starts a deferred transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
// BEGIN even if interrupted.
err := c.txExecInterrupted(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Tx{c}
}
// BeginImmediate starts an immediate transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) BeginImmediate() (Tx, error) {
err := c.Exec(`BEGIN IMMEDIATE`)
if err != nil {
return Tx{}, err
}
return Tx{c}, nil
}
// BeginExclusive starts an exclusive transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) BeginExclusive() (Tx, error) {
err := c.Exec(`BEGIN EXCLUSIVE`)
if err != nil {
return Tx{}, err
}
return Tx{c}, nil
}
// End calls either [Tx.Commit] or [Tx.Rollback]
// depending on whether *error points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// tx := conn.Begin()
// defer tx.End(&err)
//
// // ... do work in the transaction
// }
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) End(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if (errp == nil || *errp == nil) && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
}
}
// Commit commits the transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
return tx.c.txExecInterrupted(`ROLLBACK`)
}
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://www.sqlite.org/lang_savepoint.html
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
panic(err)
}
return Savepoint{c: c, name: name}
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// savept := conn.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if (errp == nil || *errp == nil) && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txExecInterrupted(fmt.Sprintf(`
ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://www.sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
}
func (c *Conn) txExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
}

16
util.go
View File

@@ -1,16 +0,0 @@
package sqlite3
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

233
vfs.go
View File

@@ -9,7 +9,6 @@ import (
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/julianday"
@@ -26,32 +25,68 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
panic(err)
}
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("go_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("go_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("go_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("go_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("go_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("go_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("go_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("go_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("go_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("go_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("go_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("go_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("go_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("go_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("go_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("go_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("go_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("go_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("go_file_control")
env := vfsNewEnvModuleBuilder(r)
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder {
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("os_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("os_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("os_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("os_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("os_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("os_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("os_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("os_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("os_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("os_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("os_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("os_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("os_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("os_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("os_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("os_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("os_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("os_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("os_file_control")
return env
}
// Poor man's namespaces.
const (
vfsOS vfsOSMethods = false
vfsFile vfsFileMethods = false
)
type (
vfsOSMethods bool
vfsFileMethods bool
)
type vfsKey struct{}
type vfsState struct {
files []*os.File
}
func vfsContext(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
_ = mod.CloseWithExitCode(ctx, exitCode)
@@ -81,7 +116,7 @@ func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uin
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.view(zByte, nByte)
mem := memory{mod}.view(zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
@@ -108,88 +143,83 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull
rel := memory{mod}.readString(zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return uint32(IOERR)
return uint32(CANTOPEN_FULLPATH)
}
// Consider either using [filepath.EvalSymlinks] to canonicalize the path (as the Unix VFS does).
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
size := uint32(len(abs) + 1)
if size > nFull {
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
return _OK
if fi, err := os.Lstat(abs); err == nil {
if fi.Mode()&fs.ModeSymlink != 0 {
return _OK_SYMLINK
}
return _OK
} else if errors.Is(err, fs.ErrNotExist) {
return _OK
}
return uint32(CANTOPEN_FULLPATH)
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) uint32 {
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _OK
return uint32(IOERR_DELETE_NOENT)
}
if err != nil {
return uint32(IOERR_DELETE)
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err == nil {
err = f.Sync()
f.Close()
}
if err != nil {
return uint32(IOERR_DELETE)
return _OK
}
defer f.Close()
err = vfsOS.Sync(f, false, false)
if err != nil {
return uint32(IOERR_DIR_FSYNC)
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
fi, err := os.Stat(path)
err := vfsOS.Access(path, flags)
var res uint32
switch {
case flags == _ACCESS_EXISTS:
var rc xErrorCode
if flags == _ACCESS_EXISTS {
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
return uint32(IOERR_ACCESS)
rc = IOERR_ACCESS
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
} else {
switch {
case err == nil:
res = 1
} else {
case errors.Is(err, fs.ErrPermission):
res = 0
default:
rc = IOERR_ACCESS
}
case errors.Is(err, fs.ErrPermission):
res = 0
default:
return uint32(IOERR_ACCESS)
}
memory{mod}.writeUint32(pResOut, res)
return _OK
return uint32(rc)
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 {
@@ -213,25 +243,17 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla
file, err = os.CreateTemp("", "*.db")
} else {
name := memory{mod}.readString(zName, _MAX_PATHNAME)
file, err = os.OpenFile(name, oflags, 0600)
file, err = vfsOS.OpenFile(name, oflags, 0600)
}
if err != nil {
return uint32(CANTOPEN)
}
if flags&OPEN_DELETEONCLOSE != 0 {
deleteOnClose(file)
os.Remove(file.Name())
}
info, err := file.Stat()
if err != nil {
return uint32(CANTOPEN)
}
if info.IsDir() {
return uint32(CANTOPEN_ISDIR)
}
id := vfsGetOpenFileID(file, info)
vfsFilePtr{mod, pFile}.SetID(id).SetLock(_NO_LOCK)
vfsFile.Open(ctx, mod, pFile, file)
if pOutFlags != 0 {
memory{mod}.writeUint32(pOutFlags, uint32(flags))
@@ -240,8 +262,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
id := vfsFilePtr{mod, pFile}.ID()
err := vfsReleaseOpenFile(id)
err := vfsFile.Close(ctx, mod, pFile)
if err != nil {
return uint32(IOERR_CLOSE)
}
@@ -249,9 +270,9 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, iAmt)
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
n, err := file.ReadAt(buf, int64(iOfst))
if n == int(iAmt) {
return _OK
@@ -266,9 +287,9 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, iAmt)
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
_, err := file.WriteAt(buf, int64(iOfst))
if err != nil {
return uint32(IOERR_WRITE)
@@ -277,7 +298,7 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOf
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
err := file.Truncate(int64(nByte))
if err != nil {
return uint32(IOERR_TRUNCATE)
@@ -285,9 +306,11 @@ func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
err := file.Sync()
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) uint32 {
dataonly := (flags & _SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == _SYNC_FULL
file := vfsFile.GetOS(ctx, mod, pFile)
err := vfsOS.Sync(file, fullsync, dataonly)
if err != nil {
return uint32(IOERR_FSYNC)
}
@@ -295,10 +318,7 @@ func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 {
// This uses [os.File.Seek] because we don't care about the offset for reading/writing.
// But consider using [os.File.Stat] instead (as other VFSes do).
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return uint32(IOERR_SEEK)
@@ -308,14 +328,43 @@ func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint3
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) uint32 {
switch op {
case _FCNTL_SIZE_HINT:
return vfsSizeHint(ctx, mod, pFile, pArg)
case _FCNTL_HAS_MOVED:
return vfsFileMoved(ctx, mod, pFile, pArg)
}
return uint32(NOTFOUND)
}
func vfsSizeHint(ctx context.Context, mod api.Module, pFile, pArg uint32) uint32 {
file := vfsFile.GetOS(ctx, mod, pFile)
size := memory{mod}.readUint64(pArg)
err := vfsOS.Allocate(file, int64(size))
if err == notImplErr {
return uint32(NOTFOUND)
}
if err != nil {
return uint32(IOERR_TRUNCATE)
}
return _OK
}
func vfsFileMoved(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
file := vfsFile.GetOS(ctx, mod, pFile)
fi, err := file.Stat()
if err != nil {
return uint32(IOERR_FSTAT)
}
pi, err := os.Stat(file.Name())
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return uint32(IOERR_FSTAT)
}
var res uint32
if !os.SameFile(fi, pi) {
res = 1
}
memory{mod}.writeUint32(pResOut, res)
return _OK
}

69
vfs_file.go Normal file
View File

@@ -0,0 +1,69 @@
package sqlite3
import (
"context"
"os"
"time"
"github.com/tetratelabs/wazero/api"
)
const (
// These need to match the offsets asserted in os.c
vfsFileIDOffset = 4
vfsFileLockOffset = 8
vfsFileLockTimeoutOffset = 12
)
func (vfsFileMethods) NewID(ctx context.Context, file *os.File) uint32 {
vfs := ctx.Value(vfsKey{}).(*vfsState)
// Find an empty slot.
for id, ptr := range vfs.files {
if ptr == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func (vfsFileMethods) Open(ctx context.Context, mod api.Module, pFile uint32, file *os.File) {
mem := memory{mod}
id := vfsFile.NewID(ctx, file)
mem.writeUint32(pFile+vfsFileIDOffset, id)
}
func (vfsFileMethods) Close(ctx context.Context, mod api.Module, pFile uint32) error {
mem := memory{mod}
id := mem.readUint32(pFile + vfsFileIDOffset)
vfs := ctx.Value(vfsKey{}).(*vfsState)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
}
func (vfsFileMethods) GetOS(ctx context.Context, mod api.Module, pFile uint32) *os.File {
mem := memory{mod}
id := mem.readUint32(pFile + vfsFileIDOffset)
vfs := ctx.Value(vfsKey{}).(*vfsState)
return vfs.files[id]
}
func (vfsFileMethods) GetLock(ctx context.Context, mod api.Module, pFile uint32) vfsLockState {
mem := memory{mod}
return vfsLockState(mem.readUint8(pFile + vfsFileLockOffset))
}
func (vfsFileMethods) SetLock(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) {
mem := memory{mod}
mem.writeUint8(pFile+vfsFileLockOffset, uint8(lock))
}
func (vfsFileMethods) GetLockTimeout(ctx context.Context, mod api.Module, pFile uint32) time.Duration {
mem := memory{mod}
return time.Duration(mem.readUint32(pFile+vfsFileLockTimeoutOffset)) * time.Millisecond
}

View File

@@ -1,107 +0,0 @@
package sqlite3
import (
"os"
"sync"
"github.com/tetratelabs/wazero/api"
)
type vfsOpenFile struct {
file *os.File
info os.FileInfo
nref int
locker vfsFileLocker
}
var (
vfsOpenFiles []*vfsOpenFile
vfsOpenFilesMtx sync.Mutex
)
func vfsGetOpenFileID(file *os.File, info os.FileInfo) uint32 {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
// Reuse an already opened file.
for id, of := range vfsOpenFiles {
if of == nil {
continue
}
if os.SameFile(info, of.info) {
of.nref++
_ = file.Close()
return uint32(id)
}
}
of := &vfsOpenFile{
file: file,
info: info,
nref: 1,
locker: vfsFileLocker{file: file},
}
// Find an empty slot.
for id, ptr := range vfsOpenFiles {
if ptr == nil {
vfsOpenFiles[id] = of
return uint32(id)
}
}
// Add a new slot.
id := len(vfsOpenFiles)
vfsOpenFiles = append(vfsOpenFiles, of)
return uint32(id)
}
func vfsReleaseOpenFile(id uint32) error {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
of := vfsOpenFiles[id]
if of.nref--; of.nref > 0 {
return nil
}
err := of.file.Close()
vfsOpenFiles[id] = nil
return err
}
type vfsFilePtr struct {
api.Module
ptr uint32
}
func (p vfsFilePtr) OSFile() *os.File {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return vfsOpenFiles[id].file
}
func (p vfsFilePtr) Locker() *vfsFileLocker {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return &vfsOpenFiles[id].locker
}
func (p vfsFilePtr) ID() uint32 {
return memory{p}.readUint32(p.ptr + ptrlen)
}
func (p vfsFilePtr) Lock() vfsLockState {
return vfsLockState(memory{p}.readUint32(p.ptr + 2*ptrlen))
}
func (p vfsFilePtr) SetID(id uint32) vfsFilePtr {
memory{p}.writeUint32(p.ptr+ptrlen, id)
return p
}
func (p vfsFilePtr) SetLock(lock vfsLockState) vfsFilePtr {
memory{p}.writeUint32(p.ptr+2*ptrlen, uint32(lock))
return p
}

View File

@@ -3,7 +3,7 @@ package sqlite3
import (
"context"
"os"
"sync"
"time"
"github.com/tetratelabs/wazero/api"
)
@@ -56,21 +56,15 @@ const (
type vfsLockState uint32
type vfsFileLocker struct {
sync.Mutex
file *os.File
state vfsLockState
shared int
}
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check. SQLite never explicitly requests a pendig lock.
// Argument check. SQLite never explicitly requests a pending lock.
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
file := vfsFile.GetOS(ctx, mod, pFile)
cLock := vfsFile.GetLock(ctx, mod, pFile)
timeout := vfsFile.GetLockTimeout(ctx, mod, pFile)
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
@@ -89,93 +83,49 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
return _OK
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
// File state check.
switch {
case fLock.state < _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
panic(assertErr())
case fLock.state == _NO_LOCK && fLock.shared != 0:
panic(assertErr())
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
panic(assertErr())
case fLock.state != _NO_LOCK && fLock.shared <= 0:
panic(assertErr())
case fLock.state < cLock:
panic(assertErr())
}
// If some other connection has a lock that precludes the requested lock, return BUSY.
if cLock != fLock.state && (eLock > _SHARED_LOCK || fLock.state >= _PENDING_LOCK) {
return uint32(BUSY)
}
switch eLock {
case _SHARED_LOCK:
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := fLock.CheckPending(); locked {
return uint32(BUSY)
}
// If some other connection has a SHARED or RESERVED lock,
// increment the reference count and return OK.
if fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK {
ptr.SetLock(_SHARED_LOCK)
fLock.shared++
return _OK
}
// Must be unlocked to get SHARED.
if fLock.state != _NO_LOCK {
if cLock != _NO_LOCK {
panic(assertErr())
}
if rc := fLock.GetShared(); rc != _OK {
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := vfsOS.CheckPendingLock(file); locked {
return uint32(BUSY)
}
if rc := vfsOS.GetSharedLock(file, timeout); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
fLock.state = _SHARED_LOCK
fLock.shared = 1
vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK)
return _OK
case _RESERVED_LOCK:
// Must be SHARED to get RESERVED.
if fLock.state != _SHARED_LOCK {
if cLock != _SHARED_LOCK {
panic(assertErr())
}
if rc := fLock.GetReserved(); rc != _OK {
if rc := vfsOS.GetReservedLock(file, timeout); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_RESERVED_LOCK)
fLock.state = _RESERVED_LOCK
vfsFile.SetLock(ctx, mod, pFile, _RESERVED_LOCK)
return _OK
case _EXCLUSIVE_LOCK:
// Must be SHARED, PENDING or RESERVED to get EXCLUSIVE.
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK {
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if cLock <= _NO_LOCK || cLock >= _EXCLUSIVE_LOCK {
panic(assertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if fLock.state == _RESERVED_LOCK {
if rc := fLock.GetPending(); rc != _OK {
if cLock < _PENDING_LOCK {
if rc := vfsOS.GetPendingLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_PENDING_LOCK)
fLock.state = _PENDING_LOCK
vfsFile.SetLock(ctx, mod, pFile, _PENDING_LOCK)
}
// We are trying for an EXCLUSIVE lock but another connection is still holding a shared lock.
if fLock.shared > 1 {
return uint32(BUSY)
}
if rc := fLock.GetExclusive(); rc != _OK {
if rc := vfsOS.GetExclusiveLock(file, timeout); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_EXCLUSIVE_LOCK)
fLock.state = _EXCLUSIVE_LOCK
vfsFile.SetLock(ctx, mod, pFile, _EXCLUSIVE_LOCK)
return _OK
default:
@@ -189,8 +139,8 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
file := vfsFile.GetOS(ctx, mod, pFile)
cLock := vfsFile.GetLock(ctx, mod, pFile)
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
@@ -202,71 +152,28 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
return _OK
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
// File state check.
switch {
case fLock.state <= _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
panic(assertErr())
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
panic(assertErr())
case fLock.shared <= 0:
panic(assertErr())
case fLock.state < cLock:
panic(assertErr())
}
if cLock > _SHARED_LOCK {
// The connection must own the lock to release it.
if cLock != fLock.state {
panic(assertErr())
switch eLock {
case _SHARED_LOCK:
if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK {
return uint32(rc)
}
if eLock == _SHARED_LOCK {
if rc := fLock.Downgrade(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
fLock.state = _SHARED_LOCK
return _OK
}
}
vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK)
return _OK
// If we get here, make sure we're dropping all locks.
if eLock != _NO_LOCK {
panic(assertErr())
}
// Release the connection lock and decrement the shared lock counter.
// Release the file lock only when all connections have released the lock.
ptr.SetLock(_NO_LOCK)
if fLock.shared--; fLock.shared == 0 {
rc := fLock.Release()
fLock.state = _NO_LOCK
case _NO_LOCK:
rc := vfsOS.ReleaseLock(file, cLock)
vfsFile.SetLock(ctx, mod, pFile, _NO_LOCK)
return uint32(rc)
default:
panic(assertErr())
}
return _OK
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
file := vfsFile.GetOS(ctx, mod, pFile)
if cLock > _SHARED_LOCK {
panic(assertErr())
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
if fLock.state >= _RESERVED_LOCK {
memory{mod}.writeUint32(pResOut, 1)
return _OK
}
locked, rc := fLock.CheckReserved()
locked, rc := vfsOS.CheckReservedLock(file)
var res uint32
if locked {
res = 1
@@ -274,3 +181,28 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ui
memory{mod}.writeUint32(pResOut, res)
return uint32(rc)
}
func (vfsOSMethods) GetSharedLock(file *os.File, timeout time.Duration) xErrorCode {
// Acquire the SHARED lock.
return vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func (vfsOSMethods) GetReservedLock(file *os.File, timeout time.Duration) xErrorCode {
// Acquire the RESERVED lock.
return vfsOS.writeLock(file, _RESERVED_BYTE, 1, timeout)
}
func (vfsOSMethods) GetPendingLock(file *os.File) xErrorCode {
// Acquire the PENDING lock.
return vfsOS.writeLock(file, _PENDING_BYTE, 1, 0)
}
func (vfsOSMethods) CheckReservedLock(file *os.File) (bool, xErrorCode) {
// Test the RESERVED lock.
return vfsOS.checkLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) CheckPendingLock(file *os.File) (bool, xErrorCode) {
// Test the PENDING lock.
return vfsOS.checkLock(file, _PENDING_BYTE, 1)
}

View File

@@ -9,12 +9,11 @@ import (
)
func Test_vfsLock(t *testing.T) {
// Other OSes lack open file descriptors locks.
switch runtime.GOOS {
case "linux", "darwin", "illumos", "windows":
//
case "linux", "darwin", "windows":
break
default:
t.Skip()
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
@@ -33,28 +32,19 @@ func Test_vfsLock(t *testing.T) {
}
defer file2.Close()
// Bypass open file reuse.
vfsOpenFiles = append(vfsOpenFiles, &vfsOpenFile{
file: file1,
nref: 1,
locker: vfsFileLocker{file: file1},
}, &vfsOpenFile{
file: file2,
nref: 1,
locker: vfsFileLocker{file: file2},
})
mem := newMemory(128)
mem.writeUint32(4+4, 0)
mem.writeUint32(16+4, 1)
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mem := newMemory(128)
ctx, vfs := vfsContext(context.TODO())
defer vfs.Close()
rc := vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
vfsFile.Open(ctx, mem.mod, pFile1, file1)
vfsFile.Open(ctx, mem.mod, pFile2, file2)
rc := vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -62,12 +52,12 @@ func Test_vfsLock(t *testing.T) {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -75,16 +65,16 @@ func Test_vfsLock(t *testing.T) {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _RESERVED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _RESERVED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -92,12 +82,12 @@ func Test_vfsLock(t *testing.T) {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _EXCLUSIVE_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _EXCLUSIVE_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -105,17 +95,33 @@ func Test_vfsLock(t *testing.T) {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsUnlock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsUnlock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}

60
vfs_os_darwin.go Normal file
View File

@@ -0,0 +1,60 @@
package sqlite3
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error {
if !fullsync {
return unix.Fsync(int(file.Fd()))
}
return file.Sync()
}
func (vfsOSMethods) Allocate(file *os.File, size int64) error {
// https://stackoverflow.com/a/11497568/867786
store := unix.Fstore_t{
Flags: unix.F_ALLOCATECONTIG,
Posmode: unix.F_PEOFPOSMODE,
Offset: 0,
Length: size,
}
// Try to get a continous chunk of disk space.
err := unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
if err != nil {
// OK, perhaps we are too fragmented, allocate non-continuous.
store.Flags = unix.F_ALLOCATEALL
return unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
}
return nil
}
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *unix.Flock_t) error {
const F_OFD_GETLK = 92 // https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
return unix.FcntlFlock(file.Fd(), F_OFD_GETLK, lock)
}
func (vfsOSMethods) fcntlSetLock(file *os.File, lock unix.Flock_t) error {
const F_OFD_SETLK = 90 // https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
return unix.FcntlFlock(file.Fd(), F_OFD_SETLK, &lock)
}
func (vfsOSMethods) fcntlSetLockTimeout(file *os.File, lock unix.Flock_t, timeout time.Duration) error {
if timeout == 0 {
return vfsOS.fcntlSetLock(file, lock)
}
const F_OFD_SETLKWTIMEOUT = 93 // https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
flocktimeout := &struct {
unix.Flock_t
unix.Timespec
}{
Flock_t: lock,
Timespec: unix.NsecToTimespec(int64(timeout / time.Nanosecond)),
}
return unix.FcntlFlock(file.Fd(), F_OFD_SETLKWTIMEOUT, &flocktimeout.Flock_t)
}

49
vfs_os_linux.go Normal file
View File

@@ -0,0 +1,49 @@
package sqlite3
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error {
if dataonly {
//lint:ignore SA1019 OK on linux
_, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0)
if err != 0 {
return err
}
return nil
}
return file.Sync()
}
func (vfsOSMethods) Allocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
return unix.Fallocate(int(file.Fd()), 0, 0, size)
}
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *unix.Flock_t) error {
return unix.FcntlFlock(file.Fd(), unix.F_OFD_GETLK, lock)
}
func (vfsOSMethods) fcntlSetLock(file *os.File, lock unix.Flock_t) error {
return unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
}
func (vfsOSMethods) fcntlSetLockTimeout(file *os.File, lock unix.Flock_t, timeout time.Duration) error {
for {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
return err
}
if timeout < time.Millisecond {
return err
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
}

30
vfs_os_posix.go Normal file
View File

@@ -0,0 +1,30 @@
//go:build !windows && !linux && !darwin
package sqlite3
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func (vfsOSMethods) Allocate(file *os.File, size int64) error {
return notImplErr
}
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *unix.Flock_t) error {
return notImplErr
}
func (vfsOSMethods) fcntlSetLock(file *os.File, lock unix.Flock_t) error {
return notImplErr
}
func (vfsOSMethods) fcntlSetLockTimeout(file *os.File, lock unix.Flock_t, timeout time.Duration) error {
return notImplErr
}

119
vfs_os_unix.go Normal file
View File

@@ -0,0 +1,119 @@
//go:build unix
package sqlite3
import (
"io/fs"
"os"
"time"
"golang.org/x/sys/unix"
)
func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func (vfsOSMethods) Access(path string, flags _AccessFlag) error {
var access uint32 = unix.F_OK
switch flags {
case _ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case _ACCESS_READ:
access = unix.R_OK
}
return unix.Access(path, access)
}
func (vfsOSMethods) GetExclusiveLock(file *os.File, timeout time.Duration) xErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return vfsOS.unlock(file, _PENDING_BYTE, 2)
}
func (vfsOSMethods) ReleaseLock(file *os.File, _ vfsLockState) xErrorCode {
// Release all locks.
return vfsOS.unlock(file, 0, 0)
}
func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode {
err := vfsOS.fcntlSetLock(file, unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (vfsOSMethods) readLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLockTimeout(file, unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}, timeout), IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode {
// TODO: implement timeouts.
return vfsOS.lockErrorCode(vfsOS.fcntlSetLockTimeout(file, unix.Flock_t{
Type: unix.F_WRLCK,
Start: start,
Len: len,
}, timeout), IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if vfsOS.fcntlGetLock(file, &lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return xErrorCode(BUSY)
case unix.EPERM:
return xErrorCode(PERM)
}
}
return def
}

240
vfs_os_windows.go Normal file
View File

@@ -0,0 +1,240 @@
package sqlite3
import (
"io"
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/windows"
)
// OpenFile is a simplified copy of [os.openFileNolog]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/os/file_windows.go
//
// See: https://go.dev/issue/32088
func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT}
}
r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
func (vfsOSMethods) Access(path string, flags _AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == _ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func (vfsOSMethods) Allocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size > off {
return file.Truncate(size)
}
return nil
}
func (vfsOSMethods) GetExclusiveLock(file *os.File, timeout time.Duration) xErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
// Reacquire the SHARED lock.
if rc != _OK {
vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
return rc
}
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) ReleaseLock(file *os.File, state vfsLockState) xErrorCode {
// Release all locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _SHARED_LOCK {
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) unlock(file *os.File, start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err == windows.ERROR_NOT_LOCKED {
return _OK
}
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (vfsOSMethods) lock(file *os.File, flags, start, len uint32, timeout time.Duration, def xErrorCode) xErrorCode {
for {
err := windows.LockFileEx(windows.Handle(file.Fd()), flags,
0, len, 0, &windows.Overlapped{Offset: start})
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
return vfsOS.lockErrorCode(err, def)
}
if timeout < time.Millisecond {
return vfsOS.lockErrorCode(err, def)
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
}
func (vfsOSMethods) readLock(file *os.File, start, len uint32, timeout time.Duration) xErrorCode {
return vfsOS.lock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, timeout, IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len uint32, timeout time.Duration) xErrorCode {
return vfsOS.lock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
start, len, timeout, IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len uint32) (bool, xErrorCode) {
rc := vfsOS.lock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, 0, IOERR_CHECKRESERVEDLOCK)
if rc == xErrorCode(BUSY) {
return true, _OK
}
if rc == _OK {
vfsOS.unlock(file, start, len)
}
return false, rc
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(windows.Errno); ok {
// https://devblogs.microsoft.com/oldnewthing/20140905-00/?p=63
switch errno {
case
windows.ERROR_LOCK_VIOLATION,
windows.ERROR_IO_PENDING,
windows.ERROR_OPERATION_ABORTED:
return xErrorCode(BUSY)
}
}
return def
}
// syscallOpen is a simplified copy of [syscall.Open]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
var access uint32
switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) {
case syscall.O_RDONLY:
access = syscall.GENERIC_READ
case syscall.O_WRONLY:
access = syscall.GENERIC_WRITE
case syscall.O_RDWR:
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
}
if mode&syscall.O_CREAT != 0 {
access |= syscall.GENERIC_WRITE
}
if mode&syscall.O_APPEND != 0 {
access &^= syscall.GENERIC_WRITE
access |= syscall.FILE_APPEND_DATA
}
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var createmode uint32
switch {
case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL):
createmode = syscall.CREATE_NEW
case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC):
createmode = syscall.CREATE_ALWAYS
case mode&syscall.O_CREAT == syscall.O_CREAT:
createmode = syscall.OPEN_ALWAYS
case mode&syscall.O_TRUNC == syscall.O_TRUNC:
createmode = syscall.TRUNCATE_EXISTING
default:
createmode = syscall.OPEN_EXISTING
}
var attrs uint32 = syscall.FILE_ATTRIBUTE_NORMAL
if perm&syscall.S_IWRITE == 0 {
attrs = syscall.FILE_ATTRIBUTE_READONLY
}
if createmode == syscall.OPEN_EXISTING && access == syscall.GENERIC_READ {
// Necessary for opening directory handles.
attrs |= syscall.FILE_FLAG_BACKUP_SEMANTICS
}
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, attrs, 0)
}

View File

@@ -16,15 +16,17 @@ import (
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1)
vfsExit(ctx, mem.mod, 1)
t.Error("want panic")
}
func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
rc := vfsLocaltime(context.TODO(), mem.mod, 0, 4)
rc := vfsLocaltime(ctx, mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -71,24 +73,26 @@ func Test_vfsRandomness(t *testing.T) {
}
func Test_vfsSleep(t *testing.T) {
start := time.Now()
ctx := context.TODO()
rc := vfsSleep(context.TODO(), 0, 123456)
now := time.Now()
rc := vfsSleep(ctx, 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
want := 123456 * time.Microsecond
if got := time.Since(start); got < want {
if got := time.Since(now); got < want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
now := time.Now()
rc := vfsCurrentTime(context.TODO(), mem.mod, 0, 4)
rc := vfsCurrentTime(ctx, mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -101,10 +105,11 @@ func Test_vfsCurrentTime(t *testing.T) {
func Test_vfsCurrentTime64(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
rc := vfsCurrentTime64(ctx, mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -119,13 +124,14 @@ func Test_vfsCurrentTime64(t *testing.T) {
func Test_vfsFullPathname(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, ".")
ctx := context.TODO()
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
rc := vfsFullPathname(ctx, mem.mod, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
rc = vfsFullPathname(ctx, mem.mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -147,8 +153,9 @@ func Test_vfsDelete(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, name)
ctx := context.TODO()
rc := vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
rc := vfsDelete(ctx, mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -157,8 +164,8 @@ func Test_vfsDelete(t *testing.T) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
rc = vfsDelete(ctx, mem.mod, 0, 4, 1)
if rc != uint32(IOERR_DELETE_NOENT) {
t.Fatal("returned", rc)
}
}
@@ -177,8 +184,9 @@ func Test_vfsAccess(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
ctx := context.TODO()
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
rc := vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -186,7 +194,7 @@ func Test_vfsAccess(t *testing.T) {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -195,7 +203,7 @@ func Test_vfsAccess(t *testing.T) {
}
mem.writeString(8, file)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -206,9 +214,11 @@ func Test_vfsAccess(t *testing.T) {
func Test_vfsFile(t *testing.T) {
mem := newMemory(128)
ctx, vfs := vfsContext(context.TODO())
defer vfs.Close()
// Open a temporary file.
rc := vfsOpen(context.TODO(), mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
rc := vfsOpen(ctx, mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -216,13 +226,13 @@ func Test_vfsFile(t *testing.T) {
// Write stuff.
text := "Hello world!"
mem.writeString(16, text)
rc = vfsWrite(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 0)
rc = vfsWrite(ctx, mem.mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
rc = vfsFileSize(ctx, mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -231,7 +241,7 @@ func Test_vfsFile(t *testing.T) {
}
// Partial read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 4)
rc = vfsRead(ctx, mem.mod, 4, 16, uint32(len(text)), 4)
if rc != uint32(IOERR_SHORT_READ) {
t.Fatal("returned", rc)
}
@@ -240,13 +250,13 @@ func Test_vfsFile(t *testing.T) {
}
// Truncate the file.
rc = vfsTruncate(context.TODO(), mem.mod, 4, 4)
rc = vfsTruncate(ctx, mem.mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
rc = vfsFileSize(ctx, mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -255,7 +265,7 @@ func Test_vfsFile(t *testing.T) {
}
// Read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 32, 4, 0)
rc = vfsRead(ctx, mem.mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -264,7 +274,7 @@ func Test_vfsFile(t *testing.T) {
}
// Close the file.
rc = vfsClose(context.TODO(), mem.mod, 4)
rc = vfsClose(ctx, mem.mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}

View File

@@ -1,158 +0,0 @@
//go:build unix
package sqlite3
import (
"os"
"runtime"
"syscall"
)
func deleteOnClose(f *os.File) {
_ = os.Remove(f.Name())
}
func (l *vfsFileLocker) GetShared() xErrorCode {
// Acquire the SHARED lock.
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
// Acquire the RESERVED lock.
return l.writeLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) GetPending() xErrorCode {
// Acquire the PENDING lock.
return l.writeLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) GetExclusive() xErrorCode {
// Acquire the EXCLUSIVE lock.
return l.writeLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
if l.state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return l.unlock(_PENDING_BYTE, 2)
}
func (l *vfsFileLocker) Release() xErrorCode {
// Release all locks.
return l.unlock(0, 0)
}
func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
// Test the RESERVED lock.
return l.checkLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
// Test the PENDING lock.
return l.checkLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) unlock(start, len int64) xErrorCode {
err := l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (l *vfsFileLocker) readLock(start, len int64) xErrorCode {
return l.errorCode(l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}), IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len int64) xErrorCode {
return l.errorCode(l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_WRLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}
func (l *vfsFileLocker) checkLock(start, len int64) (bool, xErrorCode) {
lock := syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}
if l.fcntlGetLock(&lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != syscall.F_UNLCK, _OK
}
func (l *vfsFileLocker) fcntlGetLock(lock *syscall.Flock_t) error {
F_GETLK := syscall.F_GETLK
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_GETLK = 36 // F_OFD_GETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_GETLK = 92 // F_OFD_GETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_GETLK = 47 // F_OFD_GETLK
}
return syscall.FcntlFlock(l.file.Fd(), F_GETLK, lock)
}
func (l *vfsFileLocker) fcntlSetLock(lock *syscall.Flock_t) error {
F_SETLK := syscall.F_SETLK
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_SETLK = 37 // F_OFD_SETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_SETLK = 90 // F_OFD_SETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_SETLK = 48 // F_OFD_SETLK
}
return syscall.FcntlFlock(l.file.Fd(), F_SETLK, lock)
}
func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
return xErrorCode(PERM)
}
}
return def
}

View File

@@ -1,127 +0,0 @@
package sqlite3
import (
"os"
"syscall"
"golang.org/x/sys/windows"
)
func deleteOnClose(f *os.File) {}
func (l *vfsFileLocker) GetShared() xErrorCode {
// Acquire the SHARED lock.
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
// Acquire the RESERVED lock.
return l.writeLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) GetPending() xErrorCode {
// Acquire the PENDING lock.
return l.writeLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) GetExclusive() xErrorCode {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := l.writeLock(_SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc != _OK {
l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
return rc
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
if l.state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if l.state >= _RESERVED_LOCK {
l.unlock(_RESERVED_BYTE, 1)
}
if l.state >= _PENDING_LOCK {
l.unlock(_PENDING_BYTE, 1)
}
return _OK
}
func (l *vfsFileLocker) Release() xErrorCode {
// Release all locks.
if l.state >= _RESERVED_LOCK {
l.unlock(_RESERVED_BYTE, 1)
}
if l.state >= _SHARED_LOCK {
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
}
if l.state >= _PENDING_LOCK {
l.unlock(_PENDING_BYTE, 1)
}
return _OK
}
func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
// Test the RESERVED lock.
rc := l.readLock(_RESERVED_BYTE, 1)
if rc == _OK {
l.unlock(_RESERVED_BYTE, 1)
}
return rc != _OK, _OK
}
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
// Test the PENDING lock.
rc := l.readLock(_PENDING_BYTE, 1)
if rc == _OK {
l.unlock(_PENDING_BYTE, 1)
}
return rc != _OK, _OK
}
func (l *vfsFileLocker) unlock(start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(l.file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (l *vfsFileLocker) readLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
}
func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, _ := err.(syscall.Errno); errno == windows.ERROR_INVALID_HANDLE {
return def
}
return xErrorCode(BUSY)
}