Compare commits

...

91 Commits

Author SHA1 Message Date
Nuno Cruces
759b11a05d wazero 1.0.2. 2023-04-18 23:33:56 +01:00
Nuno Cruces
93ce586139 Optimize time. 2023-04-18 01:00:59 +01:00
Nuno Cruces
2e5082c616 Query pragmas at startup. 2023-04-17 00:29:20 +01:00
Nuno Cruces
34acc28af8 Fix CI. 2023-04-14 15:48:20 +01:00
Nuno Cruces
c1a640f7d8 Build using wasi-sdk. 2023-04-14 15:31:17 +01:00
Nuno Cruces
005b15610a Memory optimizations. 2023-04-11 15:33:38 +01:00
Nuno Cruces
23ee4ccb0b Refactor. 2023-04-10 19:55:44 +01:00
Nuno Cruces
3a8cfd036d Dependencies. 2023-04-10 14:24:06 +01:00
Nuno Cruces
c38382fd8e Refactor. 2023-03-31 14:33:24 +01:00
Nuno Cruces
8509e0b6c8 Test coverage. 2023-03-31 13:42:31 +01:00
Nuno Cruces
9c07e57252 Refactor. 2023-03-29 15:06:22 +01:00
Nuno Cruces
80039385d3 Read only files. 2023-03-25 11:46:13 +00:00
Nuno Cruces
89f4327b2b Sync journal directories. 2023-03-25 11:16:51 +00:00
Nuno Cruces
37a3ff37e8 wazero 1.0. 2023-03-24 21:17:30 +00:00
Nuno Cruces
d880d6842c Refactor VFS. 2023-03-23 13:29:26 +00:00
Nuno Cruces
bef46e7954 Locking improvements (windows). 2023-03-23 12:40:55 +00:00
Nuno Cruces
4e72b4d117 Locking fix. 2023-03-23 11:26:19 +00:00
Nuno Cruces
3b08d02a83 Lock refactoring. 2023-03-23 01:55:54 +00:00
Nuno Cruces
b19c12c4c7 SQLite 3.41.2, prefer speed over size. 2023-03-23 00:44:43 +00:00
Nuno Cruces
859a21ef4e CI improvements. 2023-03-22 12:08:33 +00:00
Nuno Cruces
8ff0ee752f Use flock. 2023-03-22 03:15:54 +00:00
Nuno Cruces
589ad86f76 Extensions. 2023-03-21 00:13:12 +00:00
Nuno Cruces
1a3a1be1f6 Fix test. 2023-03-20 14:26:25 +00:00
Nuno Cruces
222c217bc8 Scripts. 2023-03-20 13:06:31 +00:00
Nuno Cruces
c1dc716391 VFS performance. 2023-03-20 11:02:34 +00:00
Nuno Cruces
71e1e5a8ee Avoid some copies. 2023-03-20 02:16:42 +00:00
Nuno Cruces
e4efb20c71 Generate coverage chart. 2023-03-18 03:51:05 +00:00
Nuno Cruces
2c9459d907 Add SQLite speedtest1. 2023-03-18 03:03:11 +00:00
Nuno Cruces
d0875e5fab Lock timeouts. 2023-03-18 01:13:31 +00:00
Nuno Cruces
15dec13f15 FCNTL_SIZE_HINT, refactor. 2023-03-17 17:13:03 +00:00
Nuno Cruces
f38e36109a FCNTL_HAS_MOVED. 2023-03-17 14:11:09 +00:00
Nuno Cruces
4cb65ccbd9 xFileControl, xDeviceCharacteristics, PSOW. 2023-03-17 13:39:19 +00:00
Nuno Cruces
f789c2fb8b OPEN_NOFOLLOW. 2023-03-16 12:27:44 +00:00
Nuno Cruces
c6a2617dfc Locking fixes. 2023-03-16 02:52:22 +00:00
Nuno Cruces
6fc0afcd12 Towards lock timeouts. 2023-03-15 13:58:16 +00:00
Nuno Cruces
77088962f5 SQLite 3.41.1. 2023-03-15 13:29:09 +00:00
Nuno Cruces
71da34861b Fix time collation. 2023-03-13 04:19:58 +00:00
Nuno Cruces
56e8281bdb Time collation tests. 2023-03-10 16:42:20 +00:00
Nuno Cruces
f61d430e65 Documentation. 2023-03-10 16:26:19 +00:00
Nuno Cruces
dbaed53b9a Sync and delete improvements. 2023-03-10 14:17:02 +00:00
Nuno Cruces
8b1bfd04e3 Simplify windows hacks. 2023-03-10 10:43:02 +00:00
Nuno Cruces
11c1687146 Time collation. 2023-03-09 14:42:29 +00:00
Nuno Cruces
94c43a8685 Use access syscall. 2023-03-09 01:59:46 +00:00
Nuno Cruces
a25159a070 Fix sharing violation. 2023-03-09 01:23:52 +00:00
Nuno Cruces
e007e9b060 Windows fixes. 2023-03-08 20:10:46 +00:00
Nuno Cruces
66a730893f Fix readonly transaction rollback. 2023-03-08 18:07:21 +00:00
Nuno Cruces
926adeb3f5 Remove MustPrepare. 2023-03-08 17:39:41 +00:00
Nuno Cruces
677f51bec1 Savepoint API. 2023-03-08 17:39:23 +00:00
Nuno Cruces
5d6f92b733 Documentation, tests, tweaks. 2023-03-08 13:29:33 +00:00
Nuno Cruces
f5747f19fb Tests. 2023-03-07 14:19:22 +00:00
Nuno Cruces
dfcdbf9c4c Online backup. 2023-03-07 12:15:29 +00:00
Nuno Cruces
ad1e8f4b0e Refactor. 2023-03-07 10:47:55 +00:00
Nuno Cruces
8f29882671 Pass mptest crash. 2023-03-07 04:37:55 +00:00
Nuno Cruces
6c96a019e6 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
d291738b81 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
c1263d4f33 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
1ebdc1aa93 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
4dd10f071a Towards shared modules: backup. 2023-03-07 04:37:55 +00:00
Nuno Cruces
7dbddfa5c0 Towards shared modules. 2023-03-07 04:37:55 +00:00
dependabot[bot]
ce5e035801 Bump golang.org/x/sys from 0.5.0 to 0.6.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.5.0 to 0.6.0.
- [Release notes](https://github.com/golang/sys/releases)
- [Commits](https://github.com/golang/sys/compare/v0.5.0...v0.6.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-03-07 04:11:45 +00:00
Nuno Cruces
8bb8367a36 Refactor mptest. 2023-03-06 14:27:49 +00:00
Nuno Cruces
9f59b3d0ec Pass mptest multiwrite on Windows. 2023-03-05 14:33:51 +00:00
Nuno Cruces
5f893b5459 Add SQLite mptest. 2023-03-05 12:20:02 +00:00
Nuno Cruces
35b1a97b88 Use finalizers to detect unclosed connections. 2023-03-03 14:50:55 +00:00
Nuno Cruces
416c3863a0 Documentation, tests, dependencies. 2023-03-01 23:47:24 +00:00
Nuno Cruces
dbc400eb15 Refactor native code. 2023-03-01 13:12:32 +00:00
Nuno Cruces
35265271aa Rename. 2023-03-01 11:18:25 +00:00
Nuno Cruces
c7165a2e56 Documentation. 2023-03-01 10:34:39 +00:00
Nuno Cruces
e64bffa520 Pragmas. 2023-02-28 16:03:31 +00:00
Nuno Cruces
54046b6adc Documentation. 2023-02-28 16:02:13 +00:00
Nuno Cruces
1b3823483f Incremental blobs. 2023-02-27 13:45:32 +00:00
Nuno Cruces
ce6d0627b2 Tests. 2023-02-27 12:07:48 +00:00
Nuno Cruces
dd30215702 Incremental blobs. 2023-02-27 04:08:55 +00:00
Nuno Cruces
21aff4c9f5 Towards incremental blobs. 2023-02-27 03:20:23 +00:00
Nuno Cruces
b30f127547 WAL mode, extensions. 2023-02-26 04:49:10 +00:00
Nuno Cruces
6509e5deb2 Transactions. 2023-02-26 03:22:08 +00:00
Nuno Cruces
125b8053f8 Fix readonly transactions. 2023-02-25 15:34:24 +00:00
Nuno Cruces
1e4a246d2f Error handling. 2023-02-25 15:11:07 +00:00
Nuno Cruces
e6cd0aaf87 MustPrepare. 2023-02-25 01:29:46 +00:00
Nuno Cruces
c1472a48b0 Tests. 2023-02-25 00:50:03 +00:00
Nuno Cruces
a69ab1ebe3 Fix data race. 2023-02-24 15:19:57 +00:00
Nuno Cruces
1190c21684 Refactor. 2023-02-24 15:06:19 +00:00
Nuno Cruces
8c28c3a6f4 Interrupt API. 2023-02-24 14:56:49 +00:00
Nuno Cruces
0146496036 Nested transactions. 2023-02-24 14:31:41 +00:00
Nuno Cruces
fcd33d2f0f Time improvements. 2023-02-24 11:09:30 +00:00
Nuno Cruces
627df5db0f No sandbox. 2023-02-23 14:16:37 +00:00
Nuno Cruces
1ed62d300d Require OFD locks. 2023-02-23 13:29:51 +00:00
Nuno Cruces
5b2451c3ad Default sector size. 2023-02-23 03:22:39 +00:00
Nuno Cruces
d52e0371eb Only reuse main db files. 2023-02-23 02:22:57 +00:00
Nuno Cruces
75f2644b0e SQLite 3.41.0. 2023-02-22 20:08:50 +00:00
Nuno Cruces
71ae26e5c9 Documentation. 2023-02-22 17:51:30 +00:00
99 changed files with 7123 additions and 2795 deletions

View File

@@ -15,12 +15,30 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
lfs: 'true'
- name: Set up Go
uses: actions/setup-go@v3
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
cache: true
- name: Format
run: gofmt -s -w . && git diff --exit-code
if: matrix.os != 'windows-latest'
- name: Tidy
run: go mod tidy && git diff --exit-code
- name: Download
run: go mod download
- name: Verify
run: go mod verify
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -28,12 +46,15 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test data races
run: go test -v -race ./...
if: matrix.os == 'ubuntu-latest'
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
if: matrix.os == 'macos-latest'
- name: Update coverage report
uses: ncruces/go-coverage-report@main
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'

5
.gitignore vendored
View File

@@ -13,7 +13,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
tools
# Project
sqlite3/sqlite3*
tools

View File

@@ -2,25 +2,82 @@
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
[![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report)
⚠️ CAUTION ⚠️
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
This is a WIP.\
DO NOT USE with data you care about.
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
Roadmap:
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] design a simple, nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on Windows/Unix
- [ ] shared memory, compatible with SQLite on Windows/Unix
- needed for improved WAL mode
- [ ] advanced features
- [ ] incremental BLOB I/O
- [ ] online backup
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
with a [pure Go](internal/vfs/) implementation.
This has numerous benefits, but also comes with some drawbacks.
#### Write-Ahead Logging
Because WASM does not support shared memory,
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
To work around this limitation, SQLite is compiled with
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
making `EXCLUSIVE` the default locking mode.
For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### POSIX Advisory Locks
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
#### Testing
The pure Go VFS is stress tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Roadmap
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [ ] session extension
- [ ] snapshot
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] in-memory VFS
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom VFS API
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

112
api.go
View File

@@ -1,112 +0,0 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"github.com/tetratelabs/wazero/api"
)
func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
getFun := func(name string) api.Function {
f := module.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getPtr := func(name string) uint32 {
global := module.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return memory{module}.readUint32(uint32(global.Get()))
}
c := Conn{
ctx: ctx,
mem: memory{module},
api: sqliteAPI{
malloc: getFun("malloc"),
free: getFun("free"),
destructor: uint64(getPtr("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
interrupt: getFun("sqlite3_interrupt"),
},
}
if err != nil {
return nil, err
}
return &c, nil
}
type sqliteAPI struct {
malloc api.Function
free api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
lastRowid api.Function
changes api.Function
interrupt api.Function
}

134
backup.go Normal file
View File

@@ -0,0 +1,134 @@
package sqlite3
// Backup is a handle to an open BLOB.
//
// https://www.sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle uint32
otherc uint32
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup calls [Open] to open the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
// https://www.sqlite.org/backup.html
func (src *Conn) Backup(srcDB, dstURI string) error {
b, err := src.BackupInit(srcDB, dstURI)
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore calls [Open] to open the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://www.sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
if err != nil {
return err
}
b, err := dst.backupInit(dst.handle, dstDB, src, "main")
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit calls [Open] to open the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupinit
func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
dst, err := src.openDB(dstURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
if err != nil {
return nil, err
}
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
defer c.arena.reset()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
other := dst
if c.handle == dst {
other = src
}
r := c.call(c.api.backupInit,
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r[0] == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r[0], dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r[0]),
}, nil
}
// Close finishes a backup operation.
//
// It is safe to close a nil, zero or closed Backup.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupfinish
func (b *Backup) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r[0])
}
// Step copies up to nPage pages between the source and destination databases.
// If nPage is negative, all remaining source pages are copied.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
if r[0] == _DONE {
return true, nil
}
return false, b.c.error(r[0])
}
// Remaining returns the number of pages still to be backed up
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
return int(r[0])
}
// PageCount returns the total number of pages in the source database
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
return int(r[0])
}

159
blob.go
View File

@@ -1,6 +1,163 @@
package sqlite3
import (
"io"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database.sql.DB.Exec] and similar methods.
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64
// Blob is a handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
//
// https://www.sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
bytes int64
offset int64
handle uint32
}
var _ io.ReadWriteSeeker = &Blob{}
// OpenBlob opens a BLOB for incremental I/O.
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
tablePtr := c.arena.string(table)
columnPtr := c.arena.string(column)
var flags uint64
if write {
flags = 1
}
r := c.call(c.api.blobOpen, uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
if err := c.error(r[0]); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
return &blob, nil
}
// Close closes a BLOB handle.
//
// It is safe to close a nil, zero or closed Blob.
//
// https://www.sqlite.org/c3ref/blob_close.html
func (b *Blob) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
b.handle = 0
return b.c.error(r[0])
}
// Size returns the size of the BLOB in bytes.
//
// https://www.sqlite.org/c3ref/blob_bytes.html
func (b *Blob) Size() int64 {
return b.bytes
}
// Read implements the [io.Reader] interface.
//
// https://www.sqlite.org/c3ref/blob_read.html
func (b *Blob) Read(p []byte) (n int, err error) {
if b.offset >= b.bytes {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
mem := util.View(b.c.mod, ptr, uint64(want))
copy(p, mem)
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
return int(want), err
}
// Write implements the [io.Writer] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
offset := b.offset
if offset > b.bytes {
offset = b.bytes
}
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
b.offset += int64(len(p))
return len(p), nil
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, util.WhenceErr
case io.SeekStart:
break
case io.SeekCurrent:
offset += b.offset
case io.SeekEnd:
offset += b.bytes
}
if offset < 0 {
return 0, util.OffsetErr
}
b.offset = offset
return offset, nil
}
// Reopen moves a BLOB handle to a new row of the same database table.
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
b.offset = 0
return b.c.error(r[0])
}

View File

@@ -1,58 +0,0 @@
package sqlite3
import (
"context"
"os"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite.
var (
Binary []byte // Binary to load.
Path string // Path to load the binary from.
)
var sqlite3 sqlite3Runtime
type sqlite3Runtime struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
}
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.once.Do(func() { s.compileModule(ctx) })
if s.err != nil {
return nil, s.err
}
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
return
}
}
if bin == nil {
s.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

437
conn.go
View File

@@ -2,64 +2,120 @@ package sqlite3
import (
"context"
"math"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Conn is a database connection handle.
// A Conn is not safe for concurrent use by multiple goroutines.
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
ctx context.Context
api sqliteAPI
mem memory
handle uint32
*module
arena arena
pending *Stmt
waiter chan struct{}
done <-chan struct{}
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
handle uint32
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
func Open(filename string) (conn *Conn, err error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE)
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
module, err := sqlite3.instantiateModule(ctx)
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE
}
return newConn(filename, flags)
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
module.Close(ctx)
mod.close()
} else {
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
c, err := newConn(ctx, module)
c := &Conn{module: mod}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
}
c.arena = c.newArena(1024)
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
defer c.arena.reset()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
if err != nil {
panic(err)
flags |= OPEN_EXRESCODE
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r[0], handle); err != nil {
c.closeDB(handle)
return 0, err
}
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
return nil, err
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
query, _ := url.ParseQuery(after)
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
}
}
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
c.closeDB(handle)
return 0, err
}
}
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r[0], handle); err != nil {
panic(err)
}
return c, nil
}
// Close closes the database connection.
@@ -68,7 +124,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
// open blob handles, and/or unfinished backup objects,
// Close will leave the database connection open and return [BUSY].
//
// It is safe to close a nil, zero or closed connection.
// It is safe to close a nil, zero or closed Conn.
//
// https://www.sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
@@ -76,82 +132,18 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(nil)
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
return err
}
c.handle = 0
return c.mem.mod.Close(c.ctx)
}
// SetInterrupt interrupts a long-running query when done is closed.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until done is reset by another call to SetInterrupt.
//
// Typically, done is provided by [context.Context.Done]:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx.Done())
// defer cancel()
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
old = c.done
c.done = done
if done == nil {
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-done: // Done was closed.
// This is safe to call from a goroutine
// because it doesn't touch the C stack.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
runtime.SetFinalizer(c, nil)
return c.module.close()
}
// Exec is a convenience function that allows an application to run
@@ -159,13 +151,11 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
//
// https://www.sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
panic(err)
}
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r[0])
}
@@ -190,16 +180,13 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
r := c.call(c.api.prepare, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
panic(err)
}
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
@@ -211,16 +198,21 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://www.sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call(c.api.autocommit, uint64(c.handle))
return r[0] != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
func (c *Conn) LastInsertRowID() int64 {
r := c.call(c.api.lastRowid, uint64(c.handle))
return int64(r[0])
}
// Changes returns the number of rows modified, inserted or deleted
@@ -228,125 +220,118 @@ func (c *Conn) LastInsertRowID() uint64 {
// on the database connection.
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
func (c *Conn) Changes() int64 {
r := c.call(c.api.changes, uint64(c.handle))
return int64(r[0])
}
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until the context is reset by another call to SetInterrupt.
//
// To associate a timeout with a connection:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx)
// defer cancel()
//
// SetInterrupt returns the old context assigned to the connection.
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
return r[0]
// Reset the pending statement.
if c.pending != nil {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
buf := util.View(c.mod, c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
buf := util.View(c.mod, c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil {
return nil, err
}
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas, stmt.Close()
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
}
var r []uint64
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
return c.module.error(rc, c.handle, sql...)
}
func (c *Conn) free(ptr uint32) {
if ptr == 0 {
return
}
_, err := c.api.free.Call(c.ctx, uint64(ptr))
if err != nil {
panic(err)
}
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints] and [incremental BLOB I/O].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
func (c *Conn) new(size uint32) uint32 {
r, err := c.api.malloc.Call(c.ctx, uint64(size))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
}
func (c *Conn) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := c.new(uint32(len(b)))
c.mem.writeBytes(ptr, b)
return ptr
}
func (c *Conn) newString(s string) uint32 {
ptr := c.new(uint32(len(s) + 1))
c.mem.writeString(ptr, s)
return ptr
}
func (c *Conn) newArena(size uint32) arena {
return arena{
c: c,
size: size,
base: c.new(size),
}
}
type arena struct {
c *Conn
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.c.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint32) uint32 {
if a.next+size <= a.size {
ptr := a.base + a.next
a.next += size
return ptr
}
ptr := a.c.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint32(len(s) + 1))
a.c.mem.writeString(ptr, s)
return ptr
Savepoint() Savepoint
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -9,8 +9,9 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_ALLOCATION_SIZE = 0x7ffffeff
ptrlen = 4
)
@@ -131,6 +132,7 @@ const (
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
@@ -165,14 +167,6 @@ const (
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_prepare_normalize.html

View File

@@ -1,4 +1,27 @@
// Package driver provides a database/sql driver for SQLite.
//
// Importing package driver registers a [database/sql] driver named "sqlite3".
// You may also need to import package embed.
//
// import _ "github.com/ncruces/go-sqlite3/driver"
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// The data source name for "sqlite3" databases can be a filename or a "file:" [URI].
//
// The [TRANSACTION] mode can be specified using "_txlock":
//
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// If no PRAGMAs are specifed, a busy timeout of 1 minute
// and normal locking mode are used.
//
// [URI]: https://www.sqlite.org/uri.html
// [PRAGMA]: https://www.sqlite.org/pragma.html
// [TRANSACTION]: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
package driver
import (
@@ -12,6 +35,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
@@ -20,61 +44,79 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
func (sqlite) Open(name string) (_ driver.Conn, err error) {
var c conn
c.conn, err = sqlite3.Open(name)
if err != nil {
return nil, err
}
var txBegin string
var pragmas strings.Builder
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
switch s := query.Get("_txlock"); s {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
pragmas = query["_pragma"]
}
}
if pragmas.Len() == 0 {
pragmas.WriteString(`PRAGMA locking_mode=normal;`)
pragmas.WriteString(`PRAGMA busy_timeout=60000;`)
if len(pragmas) == 0 {
err := c.conn.Exec(`
PRAGMA locking_mode=normal;
PRAGMA busy_timeout=60000;
`)
if err != nil {
c.Close()
return nil, err
}
c.reusable = true
} else {
s, _, err := c.conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}
err = c.Exec(pragmas.String())
if err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
return conn{
conn: c,
txBegin: txBegin,
pragmas: pragmas.String(),
}, nil
return c, nil
}
type conn struct {
conn *sqlite3.Conn
pragmas string
txBegin string
txReadOnly bool
txCommit string
txRollback string
reusable bool
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ driver.Validator = conn{}
_ driver.SessionResetter = conn{}
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.Validator = conn{}
_ sqlite3.DriverConn = conn{}
)
func (c conn) Close() error {
@@ -82,39 +124,36 @@ func (c conn) Close() error {
}
func (c conn) IsValid() bool {
// Pool only normal locking mode connections.
stmt, _, err := c.conn.Prepare(`PRAGMA locking_mode`)
if err != nil {
return false
}
defer stmt.Close()
return stmt.Step() && stmt.ColumnText(0) == "normal"
}
func (c conn) ResetSession(ctx context.Context) error {
return c.conn.Exec(c.pragmas)
return c.reusable
}
func (c conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
txBegin = `
BEGIN deferred;
PRAGMA query_only=on;
`
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
}
switch opts.Isolation {
default:
return nil, util.IsolationErr
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
}
c.txReadOnly = opts.ReadOnly
err := c.conn.Exec(txBegin)
if err != nil {
@@ -124,18 +163,15 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er
}
func (c conn) Commit() error {
if c.txReadOnly {
return c.Rollback()
}
err := c.conn.Exec(`COMMIT`)
if err != nil {
err := c.conn.Exec(c.txCommit)
if err != nil && !c.conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
return c.conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
@@ -153,20 +189,24 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
if st != nil {
s.Close()
st.Close()
return nil, tailErr
return nil, util.TailErr
}
}
return stmt{s, c.conn}, nil
}
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return c.Prepare(query)
}
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
ch := c.conn.SetInterrupt(ctx.Done())
defer c.conn.SetInterrupt(ch)
old := c.conn.SetInterrupt(ctx)
defer c.conn.SetInterrupt(old)
err := c.conn.Exec(query)
if err != nil {
@@ -174,11 +214,19 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
}
return result{
int64(c.conn.LastInsertRowID()),
int64(c.conn.Changes()),
c.conn.LastInsertRowID(),
c.conn.Changes(),
}, nil
}
func (c conn) Savepoint() sqlite3.Savepoint {
return c.conn.Savepoint()
}
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
return c.conn.OpenBlob(db, table, column, row, write)
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
@@ -270,11 +318,11 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
panic(util.AssertErr())
}
}
if err != nil {
@@ -325,8 +373,8 @@ func (r rows) Columns() []string {
}
func (r rows) Next(dest []driver.Value) error {
ch := r.conn.SetInterrupt(r.ctx.Done())
defer r.conn.SetInterrupt(ch)
old := r.conn.SetInterrupt(r.ctx)
defer r.conn.SetInterrupt(old)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
@@ -341,11 +389,10 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = r.stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeDate(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)
dest[i] = r.stmt.ColumnRawBlob(i)
case sqlite3.TEXT:
dest[i] = stringOrTime(r.stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
@@ -353,7 +400,7 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = nil
}
default:
panic(assertErr)
panic(util.AssertErr())
}
}

View File

@@ -1,4 +1,3 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
@@ -12,9 +11,12 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_Open_dir(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ".")
if err != nil {
t.Fatal(err)
@@ -25,19 +27,14 @@ func Test_Open_dir(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func Test_Open_pragma(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)")
if err != nil {
t.Fatal(err)
@@ -55,6 +52,8 @@ func Test_Open_pragma(t *testing.T) {
}
func Test_Open_pragma_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000")
if err != nil {
t.Fatal(err)
@@ -73,13 +72,15 @@ func Test_Open_pragma_invalid(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}
func Test_Open_txLock(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "test.db")+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
@@ -95,20 +96,13 @@ func Test_Open_txLock(t *testing.T) {
if err == nil {
t.Error("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
if !errors.Is(err, sqlite3.BUSY) {
t.Errorf("got %v, want sqlite3.BUSY", err)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked` {
t.Error("got message: ", got)
}
err = tx1.Commit()
if err != nil {
@@ -117,6 +111,8 @@ func Test_Open_txLock(t *testing.T) {
}
func Test_Open_txLock_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err != nil {
t.Fatal(err)
@@ -128,22 +124,26 @@ func Test_Open_txLock_invalid(t *testing.T) {
t.Fatal("want error")
}
if got := err.Error(); got != `sqlite3: invalid _txlock: xclusive` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}
func Test_BeginTx(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(isolationErr) {
if err.Error() != string(util.IsolationErr) {
t.Error("want isolationErr")
}
@@ -161,15 +161,8 @@ func Test_BeginTx(t *testing.T) {
if err == nil {
t.Error("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.READONLY {
t.Errorf("got %d, want sqlite3.READONLY", rc)
}
if got := err.Error(); got != `sqlite3: attempt to write a readonly database` {
t.Error("got message: ", got)
if !errors.Is(err, sqlite3.READONLY) {
t.Errorf("got %v, want sqlite3.READONLY", err)
}
err = tx2.Commit()
@@ -184,6 +177,8 @@ func Test_BeginTx(t *testing.T) {
}
func Test_Prepare(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
@@ -208,7 +203,7 @@ func Test_Prepare(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; SELECT`)
@@ -222,16 +217,18 @@ func Test_Prepare(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
if err.Error() != string(tailErr) {
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
}
func Test_QueryRow_named(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -282,6 +279,8 @@ func Test_QueryRow_named(t *testing.T) {
}
func Test_QueryRow_blob_null(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
@@ -300,7 +299,7 @@ func Test_QueryRow_blob_null(t *testing.T) {
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
for i := 0; rows.Next(); i++ {
var buf []byte
var buf sql.RawBytes
err = rows.Scan(&buf)
if err != nil {
t.Fatal(err)
@@ -310,39 +309,3 @@ func Test_QueryRow_blob_null(t *testing.T) {
}
}
}
func Test_ZeroBlob(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
_, err = conn.ExecContext(ctx, `INSERT INTO test(col) VALUES(?)`, sqlite3.ZeroBlob(4))
if err != nil {
t.Fatal(err)
}
var got []byte
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
}

View File

@@ -1,11 +0,0 @@
package driver
type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
isolationErr = errorString("sqlite3: unsupported isolation level")
)

View File

@@ -28,10 +28,11 @@ func Example() {
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
defer db.Close()
err = createAlbumsTable()
// Create a table with some data in it.
err = albumsSetup()
if err != nil {
log.Fatal(err)
}
@@ -58,14 +59,13 @@ func Example() {
log.Fatal(err)
}
fmt.Printf("ID of added album: %v\n", albID)
// Output:
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
// Album found: {2 Giant Steps John Coltrane 63.99}
// ID of added album: 5
}
func createAlbumsTable() error {
func albumsSetup() error {
_, err := db.Exec(`
DROP TABLE IF EXISTS album;
CREATE TABLE album (

View File

@@ -8,24 +8,24 @@ import (
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database.sql] will recover the same string.
func maybeDate(text string) driver.Value {
// but if a string is needed, [database/sql] will recover the same string.
func stringOrTime(text []byte) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return text
return string(text)
}
if len(text) > len(time.RFC3339Nano) {
return text
return string(text)
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return text
return string(text)
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && date.Format(time.RFC3339Nano) == string(text) {
return date
}
return text
return string(text)
}

View File

@@ -5,7 +5,8 @@ import (
"time"
)
func Fuzz_maybeDate(f *testing.F) {
// This checks that any string can be recovered as the same string.
func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
@@ -21,7 +22,7 @@ func Fuzz_maybeDate(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeDate(str)
value := stringOrTime([]byte(str))
switch v := value.(type) {
case time.Time:
@@ -44,3 +45,56 @@ func Fuzz_maybeDate(f *testing.F) {
}
})
}
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
f.Add(0, 999_999_999)
f.Add(0, 1_000_000_000)
f.Add(7956915742, 222_222_222) // twosday
f.Add(639095955742, 222_222_222) // twosday, year 22222AD
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
switch v := value.(type) {
case time.Time:
// Make sure times round-trip to the same time:
if !v.Equal(date) {
t.Fatalf("did not round-trip: %v", date)
}
// Make with the same zone offset:
_, off1 := v.Zone()
_, off2 := date.Zone()
if off1 != off2 {
t.Fatalf("did not round-trip: %v", date)
}
case string:
t.Fatalf("was not recovered: %v", date)
default:
t.Fatalf("invalid type %T: %v", v, date)
}
}
f.Fuzz(func(t *testing.T, sec, nsec int) {
// Reduce the search space.
if 1e12 < sec || sec < -1e12 {
// Dates before 29000BC and after 33000AD; I think we're safe.
return
}
if 0 < nsec || nsec > 1e10 {
// Out of range nsec: [time.Time.Unix] handles these.
return
}
unix := time.Unix(int64(sec), int64(nsec))
checkTime(t, unix)
checkTime(t, unix.UTC())
checkTime(t, unix.In(time.FixedZone("", -8*3600)))
checkTime(t, unix.In(time.FixedZone("", +8*3600)))
})
}

75
driver_test.go Normal file
View File

@@ -0,0 +1,75 @@
package sqlite3_test
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
func ExampleDriverConn() {
var err error
db, err = sql.Open("sqlite3", "demo.db")
if err != nil {
log.Fatal(err)
}
defer os.Remove("demo.db")
defer db.Close()
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
log.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
if err != nil {
log.Fatal(err)
}
id, err := res.LastInsertId()
if err != nil {
log.Fatal(err)
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {
return err
}
defer blob.Close()
_, err = fmt.Fprint(blob, "Hello BLOB!")
return err
})
if err != nil {
log.Fatal(err)
}
var msg string
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
if err != nil {
log.Fatal(err)
}
fmt.Println(msg)
// Output:
// Hello BLOB!
}

23
embed/README.md Normal file
View File

@@ -0,0 +1,23 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.41.2 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
- [math functions](https://www.sqlite.org/lang_mathfunc.html)
- [FTS3/4](https://www.sqlite.org/fts3.html)/[5](https://www.sqlite.org/fts5.html)
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
- [series](https://github.com/sqlite/sqlite/blob/master/ext/misc/series.c)
- [uint](https://github.com/sqlite/sqlite/blob/master/ext/misc/uint.c)
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
- [time](../sqlite3/time.c)
See the [configuration options](../sqlite3/sqlite_cfg.h).
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
and [`binaryen`](https://github.com/WebAssembly/binaryen).

View File

@@ -1,50 +1,28 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
# download SQLite
../sqlite3/download.sh
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
# build SQLite
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o sqlite3.wasm ../sqlite3/*.c \
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--initial-memory=327680 \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
-Wl,--export=malloc \
-Wl,--export=free \
-Wl,--export=malloc_destructor \
-Wl,--export=sqlite3_errcode \
-Wl,--export=sqlite3_errstr \
-Wl,--export=sqlite3_errmsg \
-Wl,--export=sqlite3_error_offset \
-Wl,--export=sqlite3_open_v2 \
-Wl,--export=sqlite3_close \
-Wl,--export=sqlite3_prepare_v3 \
-Wl,--export=sqlite3_finalize \
-Wl,--export=sqlite3_reset \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_parameter_index \
-Wl,--export=sqlite3_bind_parameter_name \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_interrupt \
$(awk '{print "-Wl,--export="$0}' exports.txt)
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

48
embed/exports.txt Normal file
View File

@@ -0,0 +1,48 @@
free
malloc
malloc_destructor
sqlite3_errcode
sqlite3_errstr
sqlite3_errmsg
sqlite3_error_offset
sqlite3_open_v2
sqlite3_close
sqlite3_close_v2
sqlite3_prepare_v3
sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
sqlite3_bind_parameter_name
sqlite3_bind_null
sqlite3_bind_int64
sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
sqlite3_column_int64
sqlite3_column_double
sqlite3_column_text
sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_interrupt_offset

View File

@@ -1,7 +1,9 @@
// Package embed embeds SQLite into your application.
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
// Importing package embed initializes the [sqlite3.Binary] variable
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
package embed
import (

Binary file not shown.

146
error.go
View File

@@ -1,7 +1,6 @@
package sqlite3
import (
"runtime"
"strconv"
"strings"
)
@@ -10,10 +9,10 @@ import (
//
// https://www.sqlite.org/c3ref/errcode.html
type Error struct {
code uint64
str string
msg string
sql string
code uint64
}
// Code returns the primary error code for this error.
@@ -50,35 +49,140 @@ func (e *Error) Error() string {
return b.String()
}
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
//
// It makes it possible to do:
//
// if errors.Is(err, sqlite3.BUSY) {
// // ... handle BUSY
// }
func (e *Error) Is(err error) bool {
switch c := err.(type) {
case ErrorCode:
return c == e.Code()
case ExtendedErrorCode:
return c == e.ExtendedCode()
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e *Error) Timeout() bool {
return e.ExtendedCode() == BUSY_TIMEOUT
}
// SQL returns the SQL starting at the token that triggered a syntax error.
func (e *Error) SQL() string {
return e.sql
}
type errorString string
// Error implements the error interface.
func (e ErrorCode) Error() string {
switch e {
case _OK:
return "sqlite3: not an error"
case _ROW:
return "sqlite3: another row available"
case _DONE:
return "sqlite3: no more rows available"
func (e errorString) Error() string { return string(e) }
const (
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
timeErr = errorString("sqlite3: invalid time value")
)
func assertErr() errorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return errorString(msg)
return "sqlite3: unknown error"
}
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY
}
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
switch x := ErrorCode(e); {
case e == ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case x < _ROW:
return x.Error()
case e == _ROW:
return "sqlite3: another row available"
case e == _DONE:
return "sqlite3: no more rows available"
}
return "sqlite3: unknown error"
}
// Is tests whether this error matches a given [ErrorCode].
func (e ExtendedErrorCode) Is(err error) bool {
c, ok := err.(ErrorCode)
return ok && c == ErrorCode(e)
}
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}

View File

@@ -1,26 +1,154 @@
package sqlite3
import (
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_assertErr(t *testing.T) {
err := util.AssertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") {
t.Errorf("got %q", s)
}
}
func TestError(t *testing.T) {
t.Parallel()
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
}
if !errors.Is(&err, ErrorCode(0x80)) {
t.Errorf("want true")
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
}
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
}
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") {
t.Errorf("got %q", s)
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
t.Errorf("want true")
}
}
func TestError_Temporary(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code uint64
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), true},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
{
err := &Error{code: tt.code}
if got := err.Temporary(); got != tt.want {
t.Errorf("Error.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ExtendedErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
}
})
}
}
func TestError_Timeout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code uint64
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), false},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
{
err := &Error{code: tt.code}
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
}
{
err := ExtendedErrorCode(tt.code)
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
}
})
}
}
func Test_ErrorCode_Error(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
}
}
func Test_ExtendedErrorCode_Error(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
}
}

View File

@@ -21,7 +21,7 @@ func Example() {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
log.Fatal(err)
}
@@ -30,6 +30,7 @@ func Example() {
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
@@ -47,7 +48,6 @@ func Example() {
if err != nil {
log.Fatal(err)
}
// Output:
// 0 go
// 1 zig

6
go.mod
View File

@@ -4,7 +4,9 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-pre.9
github.com/tetratelabs/wazero v1.0.2
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
golang.org/x/sys v0.7.0
)
retract v0.4.0 // tagged from the wrong branch

8
go.sum
View File

@@ -1,8 +1,8 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
github.com/tetratelabs/wazero v1.0.2 h1:lpwL5zczFHk2mxKur98035Gig+Z3vd9JURk6lUdZxXY=
github.com/tetratelabs/wazero v1.0.2/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

42
internal/util/error.go Normal file
View File

@@ -0,0 +1,42 @@
package util
import (
"fmt"
"runtime"
"strconv"
)
type ErrorString string
func (e ErrorString) Error() string { return string(e) }
const (
NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference")
OOMErr = ErrorString("sqlite3: out of memory")
RangeErr = ErrorString("sqlite3: index out of range")
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
NoGlobalErr = ErrorString("sqlite3: could not find global: ")
NoFuncErr = ErrorString("sqlite3: could not find function: ")
BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
TimeErr = ErrorString("sqlite3: invalid time value")
WhenceErr = ErrorString("sqlite3: invalid whence")
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
)
func AssertErr() ErrorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return ErrorString(msg)
}
func Finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(ErrorString(msg)) }
}

110
internal/util/mem.go Normal file
View File

@@ -0,0 +1,110 @@
package util
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
func View(mod api.Module, ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(NilErr)
}
if size > math.MaxUint32 {
panic(RangeErr)
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)
}
return buf
}
func ReadUint32(mod api.Module, ptr uint32) uint32 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint32(mod api.Module, ptr uint32, v uint32) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadUint64(mod api.Module, ptr uint32) uint64 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint64(mod api.Module, ptr uint32, v uint64) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadFloat64(mod api.Module, ptr uint32) float64 {
return math.Float64frombits(ReadUint64(mod, ptr))
}
func WriteFloat64(mod api.Module, ptr uint32, v float64) {
WriteUint64(mod, ptr, math.Float64bits(v))
}
func ReadString(mod api.Module, ptr, maxlen uint32) string {
if ptr == 0 {
panic(NilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
return string(buf[:i])
}
}
func WriteBytes(mod api.Module, ptr uint32, b []byte) {
buf := View(mod, ptr, uint64(len(b)))
copy(buf, b)
}
func WriteString(mod api.Module, ptr uint32, s string) {
buf := View(mod, ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

90
internal/util/mem_test.go Normal file
View File

@@ -0,0 +1,90 @@
package util
import (
"math"
"testing"
)
func TestView_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 0, 8)
t.Error("want panic")
}
func TestView_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 126, 8)
t.Error("want panic")
}
func TestView_overflow(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 1, math.MaxInt64)
t.Error("want panic")
}
func TestReadUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 0)
t.Error("want panic")
}
func TestReadUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 126)
t.Error("want panic")
}
func TestReadUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 0)
t.Error("want panic")
}
func TestReadUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 126)
t.Error("want panic")
}
func TestWriteUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 126, 1)
t.Error("want panic")
}
func TestWriteUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 126, 1)
t.Error("want panic")
}
func TestReadString_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadString(mock, 130, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -1,4 +1,4 @@
package sqlite3
package util
import (
"context"
@@ -8,13 +8,9 @@ import (
"github.com/tetratelabs/wazero/api"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func newMemory(size uint32) memory {
func NewMockModule(size uint32) api.Module {
mem := make(mockMemory, size)
return memory{mockModule{&mem}}
return mockModule{&mem}
}
type mockModule struct {
@@ -152,10 +148,6 @@ func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
return uint32(prev), true
}
func (m mockMemory) PageSize() (result uint32) {
return uint32(len(m) / 65536)
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
}

242
internal/util/mock_test.go Normal file
View File

@@ -0,0 +1,242 @@
package util
import (
"math"
"testing"
)
func Test_mockMemory_byte(t *testing.T) {
const want byte = 98
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadByte(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint16(t *testing.T) {
const want uint16 = 9876
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint16Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint16Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint32(t *testing.T) {
const want uint32 = 987654321
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint64(t *testing.T) {
const want uint64 = 9876543210
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_float32(t *testing.T) {
const want float32 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_float64(t *testing.T) {
const want float64 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_bytes(t *testing.T) {
const want string = "\xca\xfe\xba\xbe"
mock := NewMockModule(128)
_, ok := mock.Memory().Read(128, uint32(len(want)))
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(128, []byte(want))
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteString(128, want)
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(0, []byte(want))
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().Read(0, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
ok = mock.Memory().WriteString(64, want)
if !ok {
t.Error("want ok")
}
got, ok = mock.Memory().Read(64, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func Test_mockMemory_grow(t *testing.T) {
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(65536)
if ok {
t.Error("want error")
}
got, ok := mock.Memory().Grow(1)
if !ok {
t.Error("want ok")
}
if got != 1 {
t.Errorf("got %d, want 1", got)
}
_, ok = mock.Memory().ReadByte(65536)
if !ok {
t.Error("want ok")
}
}

217
internal/vfs/const.go Normal file
View File

@@ -0,0 +1,217 @@
package vfs
const (
_MAX_PATHNAME = 512
_DEFAULT_SECTOR_SIZE = 4096
)
// https://www.sqlite.org/rescode.html
type _ErrorCode uint32
const (
_OK _ErrorCode = 0 /* Successful result */
_PERM _ErrorCode = 3 /* Access permission denied */
_BUSY _ErrorCode = 5 /* The database file is locked */
_IOERR _ErrorCode = 10 /* Some kind of disk I/O error occurred */
_NOTFOUND _ErrorCode = 12 /* Unknown opcode in sqlite3_file_control() */
_CANTOPEN _ErrorCode = 14 /* Unable to open the database file */
_IOERR_READ = _IOERR | (1 << 8)
_IOERR_SHORT_READ = _IOERR | (2 << 8)
_IOERR_WRITE = _IOERR | (3 << 8)
_IOERR_FSYNC = _IOERR | (4 << 8)
_IOERR_DIR_FSYNC = _IOERR | (5 << 8)
_IOERR_TRUNCATE = _IOERR | (6 << 8)
_IOERR_FSTAT = _IOERR | (7 << 8)
_IOERR_UNLOCK = _IOERR | (8 << 8)
_IOERR_RDLOCK = _IOERR | (9 << 8)
_IOERR_DELETE = _IOERR | (10 << 8)
_IOERR_BLOCKED = _IOERR | (11 << 8)
_IOERR_NOMEM = _IOERR | (12 << 8)
_IOERR_ACCESS = _IOERR | (13 << 8)
_IOERR_CHECKRESERVEDLOCK = _IOERR | (14 << 8)
_IOERR_LOCK = _IOERR | (15 << 8)
_IOERR_CLOSE = _IOERR | (16 << 8)
_IOERR_DIR_CLOSE = _IOERR | (17 << 8)
_IOERR_SHMOPEN = _IOERR | (18 << 8)
_IOERR_SHMSIZE = _IOERR | (19 << 8)
_IOERR_SHMLOCK = _IOERR | (20 << 8)
_IOERR_SHMMAP = _IOERR | (21 << 8)
_IOERR_SEEK = _IOERR | (22 << 8)
_IOERR_DELETE_NOENT = _IOERR | (23 << 8)
_IOERR_MMAP = _IOERR | (24 << 8)
_IOERR_GETTEMPPATH = _IOERR | (25 << 8)
_IOERR_CONVPATH = _IOERR | (26 << 8)
_IOERR_VNODE = _IOERR | (27 << 8)
_IOERR_AUTH = _IOERR | (28 << 8)
_IOERR_BEGIN_ATOMIC = _IOERR | (29 << 8)
_IOERR_COMMIT_ATOMIC = _IOERR | (30 << 8)
_IOERR_ROLLBACK_ATOMIC = _IOERR | (31 << 8)
_IOERR_DATA = _IOERR | (32 << 8)
_IOERR_CORRUPTFS = _IOERR | (33 << 8)
_CANTOPEN_NOTEMPDIR = _CANTOPEN | (1 << 8)
_CANTOPEN_ISDIR = _CANTOPEN | (2 << 8)
_CANTOPEN_FULLPATH = _CANTOPEN | (3 << 8)
_CANTOPEN_CONVPATH = _CANTOPEN | (4 << 8)
_CANTOPEN_DIRTYWAL = _CANTOPEN | (5 << 8) /* Not Used */
_CANTOPEN_SYMLINK = _CANTOPEN | (6 << 8)
_OK_SYMLINK = _OK | (2 << 8) /* internal use only */
)
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type _OpenFlag uint32
const (
_OPEN_READONLY _OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
_OPEN_READWRITE _OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
_OPEN_CREATE _OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
_OPEN_DELETEONCLOSE _OpenFlag = 0x00000008 /* VFS only */
_OPEN_EXCLUSIVE _OpenFlag = 0x00000010 /* VFS only */
_OPEN_AUTOPROXY _OpenFlag = 0x00000020 /* VFS only */
_OPEN_URI _OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
_OPEN_MEMORY _OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
_OPEN_MAIN_DB _OpenFlag = 0x00000100 /* VFS only */
_OPEN_TEMP_DB _OpenFlag = 0x00000200 /* VFS only */
_OPEN_TRANSIENT_DB _OpenFlag = 0x00000400 /* VFS only */
_OPEN_MAIN_JOURNAL _OpenFlag = 0x00000800 /* VFS only */
_OPEN_TEMP_JOURNAL _OpenFlag = 0x00001000 /* VFS only */
_OPEN_SUBJOURNAL _OpenFlag = 0x00002000 /* VFS only */
_OPEN_SUPER_JOURNAL _OpenFlag = 0x00004000 /* VFS only */
_OPEN_NOMUTEX _OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
_OPEN_FULLMUTEX _OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
_OPEN_SHAREDCACHE _OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
_OPEN_PRIVATECACHE _OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
_OPEN_WAL _OpenFlag = 0x00080000 /* VFS only */
_OPEN_NOFOLLOW _OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
_OPEN_EXRESCODE _OpenFlag = 0x02000000 /* Extended result codes */
)
// https://www.sqlite.org/c3ref/c_access_exists.html
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// https://www.sqlite.org/c3ref/c_sync_dataonly.html
type _SyncFlag uint32
const (
_SYNC_NORMAL _SyncFlag = 0x00002
_SYNC_FULL _SyncFlag = 0x00003
_SYNC_DATAONLY _SyncFlag = 0x00010
)
// https://www.sqlite.org/c3ref/c_lock_exclusive.html
type _LockLevel uint32
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
_LOCK_NONE _LockLevel = 0 /* xUnlock() only */
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
_LOCK_SHARED _LockLevel = 1 /* xLock() or xUnlock() */
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
_LOCK_RESERVED _LockLevel = 2 /* xLock() only */
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
_LOCK_PENDING _LockLevel = 3 /* internal use only */
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
_LOCK_EXCLUSIVE _LockLevel = 4 /* xLock() only */
)
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type _FcntlOpcode uint32
const (
_FCNTL_LOCKSTATE _FcntlOpcode = 1
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
_FCNTL_SIZE_HINT _FcntlOpcode = 5
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
_FCNTL_FILE_POINTER _FcntlOpcode = 7
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
_FCNTL_OVERWRITE _FcntlOpcode = 11
_FCNTL_VFSNAME _FcntlOpcode = 12
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
_FCNTL_PRAGMA _FcntlOpcode = 14
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
_FCNTL_TRACE _FcntlOpcode = 19
_FCNTL_HAS_MOVED _FcntlOpcode = 20
_FCNTL_SYNC _FcntlOpcode = 21
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
_FCNTL_ZIPVFS _FcntlOpcode = 25
_FCNTL_RBU _FcntlOpcode = 26
_FCNTL_VFS_POINTER _FcntlOpcode = 27
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
_FCNTL_PDB _FcntlOpcode = 30
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
_FCNTL_DATA_VERSION _FcntlOpcode = 35
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
_FCNTL_CKPT_DONE _FcntlOpcode = 37
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
_FCNTL_CKPT_START _FcntlOpcode = 39
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
_FCNTL_CKSM_FILE _FcntlOpcode = 41
_FCNTL_RESET_CACHE _FcntlOpcode = 42
)
// https://www.sqlite.org/c3ref/c_iocap_atomic.html
type _DeviceCharacteristic uint32
const (
_IOCAP_ATOMIC _DeviceCharacteristic = 0x00000001
_IOCAP_ATOMIC512 _DeviceCharacteristic = 0x00000002
_IOCAP_ATOMIC1K _DeviceCharacteristic = 0x00000004
_IOCAP_ATOMIC2K _DeviceCharacteristic = 0x00000008
_IOCAP_ATOMIC4K _DeviceCharacteristic = 0x00000010
_IOCAP_ATOMIC8K _DeviceCharacteristic = 0x00000020
_IOCAP_ATOMIC16K _DeviceCharacteristic = 0x00000040
_IOCAP_ATOMIC32K _DeviceCharacteristic = 0x00000080
_IOCAP_ATOMIC64K _DeviceCharacteristic = 0x00000100
_IOCAP_SAFE_APPEND _DeviceCharacteristic = 0x00000200
_IOCAP_SEQUENTIAL _DeviceCharacteristic = 0x00000400
_IOCAP_UNDELETABLE_WHEN_OPEN _DeviceCharacteristic = 0x00000800
_IOCAP_POWERSAFE_OVERWRITE _DeviceCharacteristic = 0x00001000
_IOCAP_IMMUTABLE _DeviceCharacteristic = 0x00002000
_IOCAP_BATCH_ATOMIC _DeviceCharacteristic = 0x00004000
)

78
internal/vfs/func.go Normal file
View File

@@ -0,0 +1,78 @@
package vfs
import (
"context"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
func registerFunc1[T0, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
}),
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func registerFunc2[T0, T1, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func registerFunc3[T0, T1, T2, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func registerFunc4[T0, T1, T2, T3, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func registerFunc5[T0, T1, T2, T3, T4, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3, _ T4) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func registerFuncRW(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _, _, _ uint32, _ int64) _ErrorCode) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, uint32(stack[0]), uint32(stack[1]), uint32(stack[2]), int64(stack[3])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func registerFuncT(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ uint32, _ int64) _ErrorCode) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, uint32(stack[0]), int64(stack[1])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}

View File

@@ -0,0 +1,178 @@
package mptest
import (
"bytes"
"context"
"crypto/rand"
"embed"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/internal/vfs"
)
//go:embed testdata/mptest.wasm
var binary []byte
//go:embed testdata/*.*test
var scripts embed.FS
var (
rt wazero.Runtime
module wazero.CompiledModule
instances atomic.Uint64
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfs.NewEnvModuleBuilder(rt)
env.NewFunctionBuilder().WithFunc(system).Export("system")
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func config(ctx context.Context) wazero.ModuleConfig {
name := strconv.FormatUint(instances.Add(1), 10)
log := ctx.Value(logger{}).(io.Writer)
fs, err := fs.Sub(scripts, "testdata")
if err != nil {
panic(err)
}
return wazero.NewModuleConfig().
WithName(name).WithStdout(log).WithStderr(log).WithFS(fs).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
}
func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr)
buf = buf[:bytes.IndexByte(buf, 0)]
args := strings.Split(string(buf), " ")
for i := range args {
args[i] = strings.Trim(args[i], `"`)
}
args = args[:len(args)-1]
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := vfs.Context(ctx)
rt.InstantiateModule(ctx, module, cfg)
vfs.Close()
}()
return 0
}
func Test_config01(t *testing.T) {
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_config02(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_crash01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_multiwrite01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func newContext(t *testing.T) context.Context {
return context.WithValue(context.Background(), logger{}, &testWriter{T: t})
}
type logger struct{}
type testWriter struct {
*testing.T
buf []byte
mtx sync.Mutex
}
func (l *testWriter) Write(p []byte) (n int, err error) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.buf = append(l.buf, p...)
for {
before, after, found := bytes.Cut(l.buf, []byte("\n"))
if !found {
return len(p), nil
}
l.Logf("%s", before)
l.buf = after
}
}

View File

@@ -0,0 +1,2 @@
mptest.wasm filter=lfs diff=lfs merge=lfs -text
*.*test -crlf

View File

@@ -0,0 +1 @@
mptest.c

29
internal/vfs/tests/mptest/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o mptest.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid
"$BINARYEN/wasm-opt" -g -O2 mptest.wasm -o mptest.tmp \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv mptest.tmp mptest.wasm

View File

@@ -0,0 +1,46 @@
/*
** Configure five tasks in different ways, then run tests.
*/
--if vfsname() GLOB 'unix'
PRAGMA page_size=8192;
--task 1
PRAGMA journal_mode=PERSIST;
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA journal_mode=TRUNCATE;
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA journal_mode=MEMORY;
--end
--task 4
PRAGMA journal_mode=OFF;
--end
--task 4
PRAGMA mmap_size(268435456);
--end
--source multiwrite01.test
--wait all
PRAGMA page_size=16384;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384

View File

@@ -0,0 +1,123 @@
/*
** Configure five tasks in different ways, then run tests.
*/
PRAGMA page_size=512;
--task 1
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA mmap_size=8192;
--end
--task 4
PRAGMA mmap_size=65536;
--end
--task 5
PRAGMA mmap_size=268435456;
--end
--source multiwrite01.test
--source crash02.subtest
PRAGMA page_size=1024;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 1024 1024 1024 1024 1024
PRAGMA page_size=2048;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 2048 2048 2048 2048 2048
PRAGMA page_size=8192;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 8192 8192 8192 8192 8192
PRAGMA page_size=16384;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384
PRAGMA auto_vacuum=FULL;
VACUUM;
--source multiwrite01.test
--source crash02.subtest
--wait all
PRAGMA auto_vacuum=FULL;
PRAGMA page_size=512;
VACUUM;
--source multiwrite01.test
--source crash02.subtest

View File

@@ -0,0 +1,106 @@
/* Test cases involving incomplete transactions that must be rolled back.
*/
--task 1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait 1
--task 2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
INSERT INTO t2 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
INSERT INTO t3 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
INSERT INTO t4 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
INSERT INTO t5 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
/* After the database file has been set up, run the crash2 subscript
** multiple times. */
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest

View File

@@ -0,0 +1,53 @@
/*
** This script is called from crash01.test and config02.test and perhaps other
** script. After the database file has been set up, make a big rollback
** journal in client 1, then crash client 1.
** Then in the other clients, do an integrity check.
*/
--task 1 leave-hot-journal
--sleep 5
--finish
PRAGMA cache_size=10;
BEGIN;
UPDATE t1 SET b=randomblob(20000);
UPDATE t2 SET b=randomblob(20000);
UPDATE t3 SET b=randomblob(20000);
UPDATE t4 SET b=randomblob(20000);
UPDATE t5 SET b=randomblob(20000);
UPDATE t1 SET b=NULL;
UPDATE t2 SET b=NULL;
UPDATE t3 SET b=NULL;
UPDATE t4 SET b=NULL;
UPDATE t5 SET b=NULL;
--print Task one crashing an incomplete transaction
--exit 1
--end
--task 2 integrity_check-2
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 3 integrity_check-3
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 4 integrity_check-4
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 5 integrity_check-5
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--wait all

View File

@@ -0,0 +1,22 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.c"
//
#include "os.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
__attribute__((constructor)) void premain() { sqlite3_initialize(); }
static int dont_unlink(const char *pathname) { return 0; }
#define sqlite3_enable_load_extension(...)
#define sqlite3_trace(...)
#define unlink dont_unlink
#undef UNUSED_PARAMETER
#include "mptest.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37ceeed293b9f09e9770b40eda3f625447f5a3a74208709886d4411d12f93414
size 1486113

View File

@@ -0,0 +1,415 @@
/*
** This script sets up five different tasks all writing and updating
** the database at the same time, but each in its own table.
*/
--task 1 build-t1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 2 build-t2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t2 VALUES(1, randomblob(2000));
INSERT INTO t2 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t2 SELECT a+2, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+4, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+8, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+16, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+32, randomblob(1500) FROM t2;
SELECT count(*) FROM t2;
--match 64
SELECT avg(length(b)) FROM t2;
--match 1500.0
--sleep 2
UPDATE t2 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3 build-t3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t3 VALUES(1, randomblob(2000));
INSERT INTO t3 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t3 SELECT a+2, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+4, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+8, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+16, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+32, randomblob(1500) FROM t3;
SELECT count(*) FROM t3;
--match 64
SELECT avg(length(b)) FROM t3;
--match 1500.0
--sleep 2
UPDATE t3 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4 build-t4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t4 VALUES(1, randomblob(2000));
INSERT INTO t4 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t4 SELECT a+2, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+4, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+8, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+16, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+32, randomblob(1500) FROM t4;
SELECT count(*) FROM t4;
--match 64
SELECT avg(length(b)) FROM t4;
--match 1500.0
--sleep 2
UPDATE t4 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5 build-t5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t5 VALUES(1, randomblob(2000));
INSERT INTO t5 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t5 SELECT a+2, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+4, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+8, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+16, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+32, randomblob(1500) FROM t5;
SELECT count(*) FROM t5;
--match 64
SELECT avg(length(b)) FROM t5;
--match 1500.0
--sleep 2
UPDATE t5 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
SELECT count(*), sum(length(b)) FROM t1;
--match 64 247
SELECT count(*), sum(length(b)) FROM t2;
--match 64 247
SELECT count(*), sum(length(b)) FROM t3;
--match 64 247
SELECT count(*), sum(length(b)) FROM t4;
--match 64 247
SELECT count(*), sum(length(b)) FROM t5;
--match 64 247
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
--task 5
DROP INDEX t5b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t5b ON t5(b DESC);
--end
--task 3
DROP INDEX t3b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t3b ON t3(b DESC);
--end
--task 1
DROP INDEX t1b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t1b ON t1(b DESC);
--end
--task 2
DROP INDEX t2b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t2b ON t2(b DESC);
--end
--task 4
DROP INDEX t4b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t4b ON t4(b DESC);
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
VACUUM;
PRAGMA integrity_check(10);
--match ok
--task 1
UPDATE t1 SET b=randomblob(20000);
--sleep 5
UPDATE t1 SET b='x'||a||'y';
SELECT a FROM t1 WHERE b='x63y';
--match 63
--end
--task 2
UPDATE t2 SET b=randomblob(20000);
--sleep 5
UPDATE t2 SET b='x'||a||'y';
SELECT a FROM t2 WHERE b='x63y';
--match 63
--end
--task 3
UPDATE t3 SET b=randomblob(20000);
--sleep 5
UPDATE t3 SET b='x'||a||'y';
SELECT a FROM t3 WHERE b='x63y';
--match 63
--end
--task 4
UPDATE t4 SET b=randomblob(20000);
--sleep 5
UPDATE t4 SET b='x'||a||'y';
SELECT a FROM t4 WHERE b='x63y';
--match 63
--end
--task 5
UPDATE t5 SET b=randomblob(20000);
--sleep 5
UPDATE t5 SET b='x'||a||'y';
SELECT a FROM t5 WHERE b='x63y';
--match 63
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--wait all

View File

@@ -0,0 +1,85 @@
package speedtest1
import (
"bytes"
"context"
"crypto/rand"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
_ "embed"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/internal/vfs"
)
//go:embed testdata/speedtest1.wasm
var binary []byte
var (
rt wazero.Runtime
module wazero.CompiledModule
output bytes.Buffer
options []string
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfs.NewEnvModuleBuilder(rt)
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func TestMain(m *testing.M) {
i := 1
options = append(options, "speedtest1")
for _, arg := range os.Args[1:] {
if strings.HasPrefix(arg, "-test.") {
os.Args[i] = arg
i++
} else {
options = append(options, arg)
}
}
os.Args = os.Args[:i]
code := m.Run()
io.Copy(os.Stderr, &output)
os.Exit(code)
}
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := vfs.Context(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().
WithArgs(args...).WithName("speedtest1").
WithStdout(&output).WithStderr(&output).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
b.Error(err)
}
vfs.Close()
mod.Close(ctx)
}

View File

@@ -0,0 +1 @@
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1 @@
speedtest1.c

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o speedtest1.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H
"$BINARYEN/wasm-opt" -g -O2 speedtest1.wasm -o speedtest1.tmp \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv speedtest1.tmp speedtest1.wasm

View File

@@ -0,0 +1,16 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.c"
//
#include "os.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
#define randomFunc(args...) randomFunc2(args)
#include "speedtest1.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8167119c344a68217b0301e2e8c288f2e75611d296d7822f841b65911da0275c
size 1520569

399
internal/vfs/vfs.go Normal file
View File

@@ -0,0 +1,399 @@
package vfs
import (
"context"
"crypto/rand"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
func Instantiate(ctx context.Context, r wazero.Runtime) {
env := NewEnvModuleBuilder(r)
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
func NewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder {
env := r.NewHostModuleBuilder("env")
registerFuncT(env, "os_localtime", vfsLocaltime)
registerFunc3(env, "os_randomness", vfsRandomness)
registerFunc2(env, "os_sleep", vfsSleep)
registerFunc2(env, "os_current_time", vfsCurrentTime)
registerFunc2(env, "os_current_time_64", vfsCurrentTime64)
registerFunc4(env, "os_full_pathname", vfsFullPathname)
registerFunc3(env, "os_delete", vfsDelete)
registerFunc4(env, "os_access", vfsAccess)
registerFunc5(env, "os_open", vfsOpen)
registerFunc1(env, "os_close", vfsClose)
registerFuncRW(env, "os_read", vfsRead)
registerFuncRW(env, "os_write", vfsWrite)
registerFuncT(env, "os_truncate", vfsTruncate)
registerFunc2(env, "os_sync", vfsSync)
registerFunc2(env, "os_file_size", vfsFileSize)
registerFunc3(env, "os_file_control", vfsFileControl)
registerFunc1(env, "os_sector_size", vfsSectorSize)
registerFunc1(env, "os_device_characteristics", vfsDeviceCharacteristics)
registerFunc2(env, "os_lock", vfsLock)
registerFunc2(env, "os_unlock", vfsUnlock)
registerFunc2(env, "os_check_reserved_lock", vfsCheckReservedLock)
return env
}
type vfsKey struct{}
type vfsState struct {
files []vfsFile
}
func Context(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f.File != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode {
tm := time.Unix(t, 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
const size = 32 / 8
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
util.WriteUint32(mod, pTm+0*size, uint32(tm.Second()))
util.WriteUint32(mod, pTm+1*size, uint32(tm.Minute()))
util.WriteUint32(mod, pTm+2*size, uint32(tm.Hour()))
util.WriteUint32(mod, pTm+3*size, uint32(tm.Day()))
util.WriteUint32(mod, pTm+4*size, uint32(tm.Month()-time.January))
util.WriteUint32(mod, pTm+5*size, uint32(tm.Year()-1900))
util.WriteUint32(mod, pTm+6*size, uint32(tm.Weekday()-time.Sunday))
util.WriteUint32(mod, pTm+7*size, uint32(tm.YearDay()-1))
util.WriteUint32(mod, pTm+8*size, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := util.View(mod, zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, mod api.Module, pVfs, nMicro uint32) _ErrorCode {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) _ErrorCode {
day := julianday.Float(time.Now())
util.WriteFloat64(mod, prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
util.WriteUint64(mod, piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) _ErrorCode {
rel := util.ReadString(mod, zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return _CANTOPEN_FULLPATH
}
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return _CANTOPEN_FULLPATH
}
mem := util.View(mod, zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
if fi, err := os.Lstat(abs); err == nil {
if fi.Mode()&fs.ModeSymlink != 0 {
return _OK_SYMLINK
}
return _OK
} else if errors.Is(err, fs.ErrNotExist) {
return _OK
}
return _CANTOPEN_FULLPATH
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode {
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
if err != nil {
return _IOERR_DELETE
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) _ErrorCode {
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
err := osAccess(path, flags)
var res uint32
var rc _ErrorCode
if flags == _ACCESS_EXISTS {
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
rc = _IOERR_ACCESS
}
} else {
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrPermission):
res = 0
default:
rc = _IOERR_ACCESS
}
}
util.WriteUint32(mod, pResOut, res)
return rc
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags _OpenFlag, pOutFlags uint32) _ErrorCode {
var oflags int
if flags&_OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&_OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&_OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&_OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var f *os.File
if zName == 0 {
f, err = os.CreateTemp("", "*.db")
} else {
name := util.ReadString(mod, zName, _MAX_PATHNAME)
f, err = osOpenFile(name, oflags, 0666)
}
if err != nil {
return _CANTOPEN
}
if flags&_OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
file := openVFSFile(ctx, mod, pFile, f)
file.psow = true
file.readOnly = flags&_OPEN_READONLY != 0
file.syncDir = runtime.GOOS != "windows" &&
flags&(_OPEN_CREATE) != 0 &&
flags&(_OPEN_MAIN_JOURNAL|_OPEN_SUPER_JOURNAL|_OPEN_WAL) != 0
if pOutFlags != 0 {
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode {
err := closeVFSFile(ctx, mod, pFile)
if err != nil {
return _IOERR_CLOSE
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
buf := util.View(mod, zBuf, uint64(iAmt))
file := getVFSFile(ctx, mod, pFile)
n, err := file.ReadAt(buf, iOfst)
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return _IOERR_READ
}
for i := range buf[n:] {
buf[n+i] = 0
}
return _IOERR_SHORT_READ
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
buf := util.View(mod, zBuf, uint64(iAmt))
file := getVFSFile(ctx, mod, pFile)
_, err := file.WriteAt(buf, iOfst)
if err != nil {
return _IOERR_WRITE
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
err := file.Truncate(nByte)
if err != nil {
return _IOERR_TRUNCATE
}
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) _ErrorCode {
dataonly := (flags & _SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == _SYNC_FULL
file := getVFSFile(ctx, mod, pFile)
err := osSync(file.File, fullsync, dataonly)
if err != nil {
return _IOERR_FSYNC
}
if runtime.GOOS != "windows" && file.syncDir {
file.syncDir = false
f, err := os.Open(filepath.Dir(file.Name()))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return _OK
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return _IOERR_SEEK
}
util.WriteUint64(mod, pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode {
switch op {
case _FCNTL_LOCKSTATE:
util.WriteUint32(mod, pArg, uint32(getVFSFile(ctx, mod, pFile).lock))
return _OK
case _FCNTL_LOCK_TIMEOUT:
file := getVFSFile(ctx, mod, pFile)
millis := file.lockTimeout.Milliseconds()
file.lockTimeout = time.Duration(util.ReadUint32(mod, pArg)) * time.Millisecond
util.WriteUint32(mod, pArg, uint32(millis))
return _OK
case _FCNTL_POWERSAFE_OVERWRITE:
file := getVFSFile(ctx, mod, pFile)
switch util.ReadUint32(mod, pArg) {
case 0:
file.psow = false
case 1:
file.psow = true
default:
if file.psow {
util.WriteUint32(mod, pArg, 1)
} else {
util.WriteUint32(mod, pArg, 0)
}
}
case _FCNTL_SIZE_HINT:
return vfsSizeHint(ctx, mod, pFile, pArg)
case _FCNTL_HAS_MOVED:
return vfsFileMoved(ctx, mod, pFile, pArg)
}
// Consider also implementing these opcodes (in use by SQLite):
// _FCNTL_BUSYHANDLER
// _FCNTL_COMMIT_PHASETWO
// _FCNTL_PDB
// _FCNTL_PRAGMA
// _FCNTL_SYNC
return _NOTFOUND
}
func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 {
return _DEFAULT_SECTOR_SIZE
}
func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) _DeviceCharacteristic {
file := getVFSFile(ctx, mod, pFile)
if file.psow {
return _IOCAP_POWERSAFE_OVERWRITE
}
return 0
}
func vfsSizeHint(ctx context.Context, mod api.Module, pFile, pArg uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
size := util.ReadUint64(mod, pArg)
err := osAllocate(file.File, int64(size))
if err != nil {
return _IOERR_TRUNCATE
}
return _OK
}
func vfsFileMoved(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
fi, err := file.Stat()
if err != nil {
return _IOERR_FSTAT
}
pi, err := os.Stat(file.Name())
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return _IOERR_FSTAT
}
var res uint32
if !os.SameFile(fi, pi) {
res = 1
}
util.WriteUint32(mod, pResOut, res)
return _OK
}

54
internal/vfs/vfs_file.go Normal file
View File

@@ -0,0 +1,54 @@
package vfs
import (
"context"
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock _LockLevel
psow bool
syncDir bool
readOnly bool
}
func newVFSFile(vfs *vfsState, file *os.File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f.File == nil {
vfs.files[id] = vfsFile{File: file}
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, vfsFile{File: file})
return uint32(len(vfs.files) - 1)
}
func getVFSFile(ctx context.Context, mod api.Module, pFile uint32) *vfsFile {
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+4)
return &vfs.files[id]
}
func openVFSFile(ctx context.Context, mod api.Module, pFile uint32, file *os.File) *vfsFile {
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := newVFSFile(vfs, file)
util.WriteUint32(mod, pFile+4, id)
return &vfs.files[id]
}
func closeVFSFile(ctx context.Context, mod api.Module, pFile uint32) error {
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+4)
file := vfs.files[id]
vfs.files[id] = vfsFile{}
return file.Close()
}

168
internal/vfs/vfs_lock.go Normal file
View File

@@ -0,0 +1,168 @@
package vfs
import (
"context"
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock _LockLevel) _ErrorCode {
// Argument check. SQLite never explicitly requests a pending lock.
if eLock != _LOCK_SHARED && eLock != _LOCK_RESERVED && eLock != _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
file := getVFSFile(ctx, mod, pFile)
switch {
case file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
case file.lock == _LOCK_NONE && eLock > _LOCK_SHARED:
// We never move from unlocked to anything higher than a shared lock.
panic(util.AssertErr())
case file.lock != _LOCK_SHARED && eLock == _LOCK_RESERVED:
// A shared lock is always held when a reserved lock is requested.
panic(util.AssertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if file.lock >= eLock {
return _OK
}
// Do not allow any kind of write-lock on a read-only database.
if file.readOnly && eLock >= _LOCK_RESERVED {
return _IOERR_LOCK
}
switch eLock {
case _LOCK_SHARED:
// Must be unlocked to get SHARED.
if file.lock != _LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = _LOCK_SHARED
return _OK
case _LOCK_RESERVED:
// Must be SHARED to get RESERVED.
if file.lock != _LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = _LOCK_RESERVED
return _OK
case _LOCK_EXCLUSIVE:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if file.lock <= _LOCK_NONE || file.lock >= _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if file.lock < _LOCK_PENDING {
if rc := osGetPendingLock(file.File); rc != _OK {
return rc
}
file.lock = _LOCK_PENDING
}
if rc := osGetExclusiveLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = _LOCK_EXCLUSIVE
return _OK
default:
panic(util.AssertErr())
}
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock _LockLevel) _ErrorCode {
// Argument check.
if eLock != _LOCK_NONE && eLock != _LOCK_SHARED {
panic(util.AssertErr())
}
file := getVFSFile(ctx, mod, pFile)
// Connection state check.
if file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// If we don't have a more restrictive lock, do nothing.
if file.lock <= eLock {
return _OK
}
switch eLock {
case _LOCK_SHARED:
if rc := osDowngradeLock(file.File, file.lock); rc != _OK {
return rc
}
file.lock = _LOCK_SHARED
return _OK
case _LOCK_NONE:
rc := osReleaseLock(file.File, file.lock)
file.lock = _LOCK_NONE
return rc
default:
panic(util.AssertErr())
}
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
// Connection state check.
if file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
var locked bool
var rc _ErrorCode
if file.lock >= _LOCK_RESERVED {
locked = true
} else {
locked, rc = osCheckReservedLock(file.File)
}
var res uint32
if locked {
res = 1
}
util.WriteUint32(mod, pResOut, res)
return rc
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

View File

@@ -0,0 +1,172 @@
package vfs
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mod := util.NewMockModule(128)
ctx, vfs := Context(context.TODO())
defer vfs.Close()
openVFSFile(ctx, mod, pFile1, file1)
openVFSFile(ctx, mod, pFile2, file2)
rc := vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_RESERVED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_EXCLUSIVE)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(ctx, mod, pFile1, _LOCK_SHARED)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsUnlock(ctx, mod, pFile2, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile1, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -0,0 +1,56 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
if start == 0 && len == 0 {
err := unix.Flock(int(file.Fd()), unix.LOCK_UN)
if err != nil {
return _IOERR_UNLOCK
}
}
return _OK
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_SH|unix.LOCK_NB, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_EX|unix.LOCK_NB, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,106 @@
//go:build !sqlite3_bsd
package vfs
import (
"io"
"os"
"time"
"golang.org/x/sys/unix"
)
const (
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
_F_OFD_SETLK = 90
_F_OFD_SETLKW = 91
_F_OFD_GETLK = 92
_F_OFD_SETLKWTIMEOUT = 93
)
type flocktimeout_t struct {
fl unix.Flock_t
timeout unix.Timespec
}
func osSync(file *os.File, fullsync, dataonly bool) error {
if fullsync {
return file.Sync()
}
return unix.Fsync(int(file.Fd()))
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
// https://stackoverflow.com/a/11497568/867786
store := unix.Fstore_t{
Flags: unix.F_ALLOCATECONTIG,
Posmode: unix.F_PEOFPOSMODE,
Offset: 0,
Length: size,
}
// Try to get a continous chunk of disk space.
err = unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
if err != nil {
// OK, perhaps we are too fragmented, allocate non-continuous.
store.Flags = unix.F_ALLOCATEALL
unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
}
return file.Truncate(size)
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := flocktimeout_t{fl: unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}}
var err error
if timeout == 0 {
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &lock.fl)
} else {
lock.timeout = unix.NsecToTimespec(int64(timeout / time.Nanosecond))
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLKWTIMEOUT, &lock.fl)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), _F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,25 @@
package vfs
import (
"os"
"golang.org/x/sys/unix"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
if dataonly {
_, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0)
if err != 0 {
return err
}
return nil
}
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
return unix.Fallocate(int(file.Fd()), 0, 0, size)
}

View File

@@ -0,0 +1,63 @@
//go:build linux || illumos
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,23 @@
//go:build !linux && (!darwin || sqlite3_bsd)
package vfs
import (
"io"
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
return file.Truncate(size)
}

View File

@@ -0,0 +1,87 @@
//go:build unix
package vfs
import (
"io/fs"
"os"
"time"
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags _AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
case _ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case _ACCESS_READ:
access = unix.R_OK
}
return unix.Access(path, access)
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _ErrorCode(_BUSY)
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
if state >= _LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ _LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _ErrorCode(_BUSY)
case unix.EPERM:
return _ErrorCode(_PERM)
}
}
return def
}

View File

@@ -0,0 +1,240 @@
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/windows"
)
// osOpenFile is a simplified copy of [os.openFileNolog]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/os/file_windows.go
//
// See: https://go.dev/issue/32088
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT}
}
r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags _AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == _ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
if rc == _OK {
// Acquire the SHARED lock.
rc = osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
// Release the PENDING lock.
osUnlock(file, _PENDING_BYTE, 1)
}
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
if rc != _OK {
// Reacquire the SHARED lock.
osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
return rc
}
func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
if state >= _LOCK_EXCLUSIVE {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= _LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= _LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osReleaseLock(file *os.File, state _LockLevel) _ErrorCode {
// Release all locks.
if state >= _LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= _LOCK_SHARED {
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= _LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err == windows.ERROR_NOT_LOCKED {
return _OK
}
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
0, len, 0, &windows.Overlapped{Offset: start})
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len uint32) (bool, _ErrorCode) {
rc := osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, 0, _IOERR_CHECKRESERVEDLOCK)
if rc == _BUSY {
return true, _OK
}
if rc == _OK {
osUnlock(file, start, len)
}
return false, rc
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(windows.Errno); ok {
// https://devblogs.microsoft.com/oldnewthing/20140905-00/?p=63
switch errno {
case
windows.ERROR_LOCK_VIOLATION,
windows.ERROR_IO_PENDING,
windows.ERROR_OPERATION_ABORTED:
return _BUSY
}
}
return def
}
// syscallOpen is a simplified copy of [syscall.Open]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
var access uint32
switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) {
case syscall.O_RDONLY:
access = syscall.GENERIC_READ
case syscall.O_WRONLY:
access = syscall.GENERIC_WRITE
case syscall.O_RDWR:
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
}
if mode&syscall.O_CREAT != 0 {
access |= syscall.GENERIC_WRITE
}
if mode&syscall.O_APPEND != 0 {
access &^= syscall.GENERIC_WRITE
access |= syscall.FILE_APPEND_DATA
}
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var createmode uint32
switch {
case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL):
createmode = syscall.CREATE_NEW
case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC):
createmode = syscall.CREATE_ALWAYS
case mode&syscall.O_CREAT == syscall.O_CREAT:
createmode = syscall.OPEN_ALWAYS
case mode&syscall.O_TRUNC == syscall.O_TRUNC:
createmode = syscall.TRUNCATE_EXISTING
default:
createmode = syscall.OPEN_EXISTING
}
var attrs uint32 = syscall.FILE_ATTRIBUTE_NORMAL
if perm&syscall.S_IWRITE == 0 {
attrs = syscall.FILE_ATTRIBUTE_READONLY
}
if createmode == syscall.OPEN_EXISTING && access == syscall.GENERIC_READ {
// Necessary for opening directory handles.
attrs |= syscall.FILE_FLAG_BACKUP_SEMANTICS
}
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, attrs, 0)
}

276
internal/vfs/vfs_test.go Normal file
View File

@@ -0,0 +1,276 @@
package vfs
import (
"bytes"
"context"
"errors"
"io/fs"
"os"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
func Test_vfsLocaltime(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
tm := time.Now()
rc := vfsLocaltime(ctx, mod, 4, tm.Unix())
if rc != 0 {
t.Fatal("returned", rc)
}
if s := util.ReadUint32(mod, 4+0*4); int(s) != tm.Second() {
t.Error("wrong second")
}
if m := util.ReadUint32(mod, 4+1*4); int(m) != tm.Minute() {
t.Error("wrong minute")
}
if h := util.ReadUint32(mod, 4+2*4); int(h) != tm.Hour() {
t.Error("wrong hour")
}
if d := util.ReadUint32(mod, 4+3*4); int(d) != tm.Day() {
t.Error("wrong day")
}
if m := util.ReadUint32(mod, 4+4*4); time.Month(1+m) != tm.Month() {
t.Error("wrong month")
}
if y := util.ReadUint32(mod, 4+5*4); 1900+int(y) != tm.Year() {
t.Error("wrong year")
}
if w := util.ReadUint32(mod, 4+6*4); time.Weekday(w) != tm.Weekday() {
t.Error("wrong weekday")
}
if d := util.ReadUint32(mod, 4+7*4); int(d) != tm.YearDay()-1 {
t.Error("wrong yearday")
}
}
func Test_vfsRandomness(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
rc := vfsRandomness(ctx, mod, 0, 16, 4)
if rc != 16 {
t.Fatal("returned", rc)
}
var zero [16]byte
if got := util.View(mod, 4, 16); bytes.Equal(got, zero[:]) {
t.Fatal("all zero")
}
}
func Test_vfsSleep(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
rc := vfsSleep(ctx, mod, 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
want := 123456 * time.Microsecond
if got := time.Since(now); got < want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
rc := vfsCurrentTime(ctx, mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
want := julianday.Float(now)
if got := util.ReadFloat64(mod, 4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime64(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(ctx, mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000
if got := util.ReadUint64(mod, 4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsFullPathname(t *testing.T) {
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 4, ".")
ctx := context.TODO()
rc := vfsFullPathname(ctx, mod, 0, 4, 0, 8)
if rc != _CANTOPEN_FULLPATH {
t.Errorf("returned %d, want %d", rc, _CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(ctx, mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
want, _ := filepath.Abs(".")
if got := util.ReadString(mod, 8, _MAX_PATHNAME); got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsDelete(t *testing.T) {
name := filepath.Join(t.TempDir(), "test.db")
file, err := os.Create(name)
if err != nil {
t.Fatal(err)
}
file.Close()
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 4, name)
ctx := context.TODO()
rc := vfsDelete(ctx, mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(ctx, mod, 0, 4, 1)
if rc != _IOERR_DELETE_NOENT {
t.Fatal("returned", rc)
}
}
func Test_vfsAccess(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
f.Close()
}
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
t.Fatal(err)
}
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 8, dir)
ctx := context.TODO()
rc := vfsAccess(ctx, mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("directory did not exist")
}
rc = vfsAccess(ctx, mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("can't access directory")
}
util.WriteString(mod, 8, file)
rc = vfsAccess(ctx, mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {
mod := util.NewMockModule(128)
ctx, vfs := Context(context.TODO())
defer vfs.Close()
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, _OPEN_CREATE|_OPEN_EXCLUSIVE|_OPEN_READWRITE|_OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Write stuff.
text := "Hello world!"
util.WriteString(mod, 16, text)
rc = vfsWrite(ctx, mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(ctx, mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got != uint32(len(text)) {
t.Errorf("got %d", got)
}
// Partial read at offset.
rc = vfsRead(ctx, mod, 4, 16, uint32(len(text)), 4)
if rc != _IOERR_SHORT_READ {
t.Fatal("returned", rc)
}
if got := util.ReadString(mod, 16, 64); got != text[4:] {
t.Errorf("got %q", got)
}
// Truncate the file.
rc = vfsTruncate(ctx, mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(ctx, mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got != 4 {
t.Errorf("got %d", got)
}
// Read at offset.
rc = vfsRead(ctx, mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadString(mod, 32, 64); got != text[:4] {
t.Errorf("got %q", got)
}
// Close the file.
rc = vfsClose(ctx, mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
}

111
mem.go
View File

@@ -1,111 +0,0 @@
package sqlite3
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
type memory struct {
mod api.Module
}
func (m memory) view(ptr, size uint32) []byte {
if ptr == 0 {
panic(nilErr)
}
buf, ok := m.mod.Memory().Read(ptr, size)
if !ok {
panic(rangeErr)
}
return buf
}
func (m memory) readUint32(ptr uint32) uint32 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint32(ptr, v uint32) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readUint64(ptr uint32) uint64 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint64(ptr uint32, v uint64) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readFloat64(ptr uint32) float64 {
return math.Float64frombits(m.readUint64(ptr))
}
func (m memory) writeFloat64(ptr uint32, v float64) {
m.writeUint64(ptr, math.Float64bits(v))
}
func (m memory) readString(ptr, maxlen uint32) string {
if ptr == 0 {
panic(nilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := m.mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(rangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(noNulErr)
} else {
return string(buf[:i])
}
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint32(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
buf := m.view(ptr, uint32(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -1,83 +0,0 @@
package sqlite3
import (
"math"
"testing"
)
func Test_memory_view_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(0, 8)
t.Error("want panic")
}
func Test_memory_view_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(126, 8)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(0)
t.Error("want panic")
}
func Test_memory_readUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(126)
t.Error("want panic")
}
func Test_memory_readUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(0)
t.Error("want panic")
}
func Test_memory_readUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(126)
t.Error("want panic")
}
func Test_memory_writeUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(126, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(126, 1)
t.Error("want panic")
}
func Test_memory_readString_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readString(130, math.MaxUint32)
t.Error("want panic")
}

346
module.go Normal file
View File

@@ -0,0 +1,346 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"io"
"math"
"os"
"sync"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/internal/vfs"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
var sqlite3 struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
ctx := context.Background()
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
vfs.Instantiate(ctx, sqlite3.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
return
}
}
if bin == nil {
sqlite3.err = util.BinaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mod api.Module
vfs io.Closer
api sqliteAPI
arg []uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mod = mod
m.ctx, m.vfs = vfs.Context(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
if f == nil {
err = util.NoFuncErr + util.ErrorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
g := mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return util.ReadUint32(mod, uint32(g.Get()))
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
interrupt: getVal("sqlite3_interrupt_offset"),
}
if err != nil {
return nil, err
}
return m, nil
}
func (m *module) close() error {
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(util.OOMErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r = m.call(m.api.erroff, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
m.arg = append(m.arg[:0], params...)
r, err := fn.Call(m.ctx, m.arg...)
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return r
}
func (m *module) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
util.WriteBytes(m.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
util.WriteString(m.mod, ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
size: uint32(size),
}
}
type arena struct {
m *module
ptrs []uint32
base uint32
next uint32
size uint32
}
func (a *arena) free() {
if a.m == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
util.WriteString(a.m.mod, ptr, s)
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
destructor uint32
interrupt uint32
}

View File

@@ -4,33 +4,76 @@ import (
"bytes"
"math"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer func() { _ = recover() }()
m.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer func() { _ = recover() }()
m.call(m.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.new(math.MaxUint32)
t.Error("want panic")
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(math.MaxUint32)
t.Error("want panic")
})
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
arena := db.newArena(16)
defer arena.reset()
arena := m.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -38,7 +81,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -47,33 +90,34 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
}
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newBytes(nil)
ptr := m.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = db.newBytes(buf)
ptr = m.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := db.mem.view(ptr, uint32(len(want))); !bytes.Equal(got, want) {
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -81,25 +125,25 @@ func TestConn_newBytes(t *testing.T) {
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := db.mem.view(ptr, uint32(len(want))); string(got) != want {
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -107,40 +151,40 @@ func TestConn_newString(t *testing.T) {
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := db.mem.readString(ptr, 0); got != "" {
if got := util.ReadString(m.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
db.mem.readString(ptr, uint32(len(want)/2))
util.ReadString(m.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
db.mem.readString(0, math.MaxUint32)
util.ReadString(m.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}
@@ -148,18 +192,18 @@ func TestConn_getString(t *testing.T) {
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
db.free(0)
m.free(0)
ptr := db.new(1)
ptr := m.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
db.free(ptr)
m.free(ptr)
}

3
sqlite3/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
sqlite3.c
sqlite3.h
sqlite3ext.h

View File

@@ -1,13 +1,31 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "sqlite3.c" ]; then
url="https://www.sqlite.org/2022/sqlite-amalgamation-3400100.zip"
curl "$url" > sqlite.zip
unzip -d . sqlite.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
rm sqlite.zip
fi
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3410200.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/series.c"
cd ~-
cd ../internal/vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/multiwrite01.test"
cd ~-
cd ../internal/vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/test/speedtest1.c"
cd ~-

1
sqlite3/ext/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.c

7
sqlite3/format.sh Executable file
View File

@@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
shopt -s extglob
clang-format -i !(sqlite3*).@(c|h)

View File

@@ -1,102 +1,32 @@
#include <stdbool.h>
#include <stdlib.h>
#include <time.h>
#include <stddef.h>
#include "sqlite3.h"
#include "sqlite3.c"
//
#include "os.c"
//
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
#include "time.c"
int main() {
int rc = sqlite3_initialize();
if (rc != SQLITE_OK) return 1;
}
int go_localtime(sqlite3_int64, struct tm *);
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
int go_sleep(sqlite3_vfs *, int microseconds);
int go_current_time(sqlite3_vfs *, double *);
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct go_file {
sqlite3_file base;
int id;
int eLock;
};
int go_close(sqlite3_file *);
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int go_truncate(sqlite3_file *, sqlite3_int64 size);
int go_sync(sqlite3_file *, int flags);
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int go_file_control(sqlite3_file *pFile, int op, void *pArg);
int go_lock(sqlite3_file *pFile, int eLock);
int go_unlock(sqlite3_file *pFile, int eLock);
int go_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime((sqlite3_int64)*pTime, pTm);
}
static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods go_io = {
.iVersion = 1,
.xClose = go_close,
.xRead = go_read,
.xWrite = go_write,
.xTruncate = go_truncate,
.xSync = go_sync,
.xFileSize = go_file_size,
.xLock = go_lock,
.xUnlock = go_unlock,
.xCheckReservedLock = go_check_reserved_lock,
.xFileControl = no_file_control,
.xSectorSize = no_sector_size,
.xDeviceCharacteristics = no_device_characteristics,
};
int rc = go_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &go_io : NULL;
return rc;
}
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
static sqlite3_vfs go_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = "go",
.xOpen = go_open_c,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&go_vfs, /*default=*/true);
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;
__attribute__((constructor)) void init() {
sqlite3_initialize();
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
sqlite3_auto_extension((void (*)(void))sqlite3_decimal_init);
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}

97
sqlite3/os.c Normal file
View File

@@ -0,0 +1,97 @@
#include <time.h>
#include "sqlite3.h"
int os_localtime(struct tm *, sqlite3_int64);
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
int os_sleep(sqlite3_vfs *, int microseconds);
int os_current_time(sqlite3_vfs *, double *);
int os_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int os_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int os_delete(sqlite3_vfs *, const char *zName, int syncDir);
int os_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct os_file {
sqlite3_file base;
int handle;
};
static_assert(offsetof(struct os_file, handle) == 4, "Unexpected offset");
int os_close(sqlite3_file *);
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int os_truncate(sqlite3_file *, sqlite3_int64 size);
int os_sync(sqlite3_file *, int flags);
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int os_file_control(sqlite3_file *, int op, void *pArg);
int os_sector_size(sqlite3_file *file);
int os_device_characteristics(sqlite3_file *file);
int os_lock(sqlite3_file *, int eLock);
int os_unlock(sqlite3_file *, int eLock);
int os_check_reserved_lock(sqlite3_file *, int *pResOut);
static int os_file_control_w(sqlite3_file *file, int op, void *pArg) {
struct os_file *pFile = (struct os_file *)file;
if (op == SQLITE_FCNTL_VFSNAME) {
*(char **)pArg = sqlite3_mprintf("%s", "os");
return SQLITE_OK;
}
return os_file_control(file, op, pArg);
}
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = os_close,
.xRead = os_read,
.xWrite = os_write,
.xTruncate = os_truncate,
.xSync = os_sync,
.xFileSize = os_file_size,
.xLock = os_lock,
.xUnlock = os_unlock,
.xCheckReservedLock = os_check_reserved_lock,
.xFileControl = os_file_control_w,
.xSectorSize = os_sector_size,
.xDeviceCharacteristics = os_device_characteristics,
};
memset(file, 0, sizeof(struct os_file));
int rc = os_open(vfs, zName, file, flags, pOutFlags);
if (rc) {
return rc;
}
file->pMethods = &os_io;
return SQLITE_OK;
}
sqlite3_vfs *os_vfs() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct os_file),
.mxPathname = 512,
.zName = "os",
.xOpen = os_open_w,
.xDelete = os_delete,
.xAccess = os_access,
.xFullPathname = os_full_pathname,
.xRandomness = os_randomness,
.xSleep = os_sleep,
.xCurrentTime = os_current_time,
.xCurrentTimeInt64 = os_current_time_64,
};
return &os_vfs;
}
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime(pTm, (sqlite3_int64)*pTime);
}

View File

@@ -28,18 +28,32 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Other Options
// #define SQLITE_ALLOW_URI_AUTHORITY
// Because WASM does not support shared memory,
// SQLite disables WAL for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#ifndef SQLITE_DEFAULT_LOCKING_MODE
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Recommended Extensions
// #define SQLITE_ENABLE_MATH_FUNCTIONS 1
// #define SQLITE_ENABLE_FTS3 1
// #define SQLITE_ENABLE_FTS3_PARENTHESIS 1
// #define SQLITE_ENABLE_FTS4 1
// #define SQLITE_ENABLE_FTS5 1
// #define SQLITE_ENABLE_RTREE 1
// #define SQLITE_ENABLE_GEOPOLY 1
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
#define SQLITE_ENABLE_JSON1 1
#define SQLITE_ENABLE_FTS3 1
#define SQLITE_ENABLE_FTS3_PARENTHESIS 1
#define SQLITE_ENABLE_FTS4 1
#define SQLITE_ENABLE_FTS5 1
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
// Need this to access WAL databases without the use of shared memory.
#define SQLITE_DEFAULT_LOCKING_MODE 1
// Session Extension
// #define SQLITE_ENABLE_SESSION 1
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
// Implemented in Go.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

32
sqlite3/time.c Normal file
View File

@@ -0,0 +1,32 @@
#include <string.h>
#include "sqlite3.h"
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
const void *pKey2) {
// Remove a Z suffix if one key is no longer than the other.
// A Z suffix collates before any character but after the empty string.
// This avoids making different keys equal.
const int nK1 = nKey1;
const int nK2 = nKey2;
const char *pK1 = (const char *)pKey1;
const char *pK2 = (const char *)pKey2;
if (nK1 && nK1 <= nK2 && pK1[nK1 - 1] == 'Z') {
nKey1--;
}
if (nK2 && nK2 <= nK1 && pK2[nK2 - 1] == 'Z') {
nKey2--;
}
int n = nKey1 < nKey2 ? nKey1 : nKey2;
int rc = memcmp(pKey1, pKey2, n);
if (rc == 0) {
rc = nKey1 - nKey2;
}
return rc;
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
}

219
stmt.go
View File

@@ -3,6 +3,8 @@ package sqlite3
import (
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Stmt is a prepared statement object.
@@ -10,13 +12,13 @@ import (
// https://www.sqlite.org/c3ref/stmt.html
type Stmt struct {
c *Conn
handle uint32
err error
handle uint32
}
// Close destroys the prepared statement object.
//
// It is safe to close a nil, zero or closed prepared statement.
// It is safe to close a nil, zero or closed Stmt.
//
// https://www.sqlite.org/c3ref/finalize.html
func (s *Stmt) Close() error {
@@ -24,10 +26,7 @@ func (s *Stmt) Close() error {
return nil
}
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.finalize, uint64(s.handle))
s.handle = 0
return s.c.error(r[0])
@@ -37,10 +36,7 @@ func (s *Stmt) Close() error {
//
// https://www.sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.reset, uint64(s.handle))
s.err = nil
return s.c.error(r[0])
}
@@ -49,10 +45,7 @@ func (s *Stmt) Reset() error {
//
// https://www.sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r[0])
}
@@ -66,10 +59,8 @@ func (s *Stmt) ClearBindings() error {
//
// https://www.sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r[0] == _ROW {
return true
}
@@ -101,11 +92,8 @@ func (s *Stmt) Exec() error {
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r, err := s.c.api.bindCount.Call(s.c.ctx,
r := s.c.call(s.c.api.bindCount,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -116,11 +104,8 @@ func (s *Stmt) BindCount() int {
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.reset()
namePtr := s.c.arena.string(name)
r, err := s.c.api.bindIndex.Call(s.c.ctx,
r := s.c.call(s.c.api.bindIndex,
uint64(s.handle), uint64(namePtr))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -129,17 +114,14 @@ func (s *Stmt) BindIndex(name string) int {
//
// https://www.sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r, err := s.c.api.bindName.Call(s.c.ctx,
r := s.c.call(s.c.api.bindName,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// BindBool binds a bool to the prepared statement.
@@ -168,11 +150,8 @@ func (s *Stmt) BindInt(param int, value int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
r, err := s.c.api.bindInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.bindInteger,
uint64(s.handle), uint64(param), uint64(value))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -181,11 +160,8 @@ func (s *Stmt) BindInt64(param int, value int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
r, err := s.c.api.bindFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.bindFloat,
uint64(s.handle), uint64(param), math.Float64bits(value))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -195,13 +171,10 @@ func (s *Stmt) BindFloat(param int, value float64) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindText(param int, value string) error {
ptr := s.c.newString(value)
r, err := s.c.api.bindText.Call(s.c.ctx,
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
if err != nil {
panic(err)
}
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
}
@@ -212,13 +185,10 @@ func (s *Stmt) BindText(param int, value string) error {
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
ptr := s.c.newBytes(value)
r, err := s.c.api.bindBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
if err != nil {
panic(err)
}
uint64(s.c.api.destructor))
return s.c.error(r[0])
}
@@ -227,11 +197,8 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r, err := s.c.api.bindZeroBlob.Call(s.c.ctx,
r := s.c.call(s.c.api.bindZeroBlob,
uint64(s.handle), uint64(param), uint64(n))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -240,11 +207,8 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
r, err := s.c.api.bindNull.Call(s.c.ctx,
r := s.c.call(s.c.api.bindNull,
uint64(s.handle), uint64(param))
if err != nil {
panic(err)
}
return s.c.error(r[0])
}
@@ -253,6 +217,9 @@ func (s *Stmt) BindNull(param int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
if format == TimeFormatDefault {
return s.bindRFC3339Nano(param, value)
}
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
@@ -261,20 +228,31 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
case float64:
s.BindFloat(param, v)
default:
panic(assertErr())
panic(util.AssertErr())
}
return nil
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r, err := s.c.api.columnCount.Call(s.c.ctx,
r := s.c.call(s.c.api.columnCount,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
@@ -283,17 +261,14 @@ func (s *Stmt) ColumnCount() int {
//
// https://www.sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r, err := s.c.api.columnName.Call(s.c.ctx,
r := s.c.call(s.c.api.columnName,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
panic(oomErr)
panic(util.OOMErr)
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// ColumnType returns the initial [Datatype] of the result column.
@@ -301,11 +276,8 @@ func (s *Stmt) ColumnName(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
r, err := s.c.api.columnType.Call(s.c.ctx,
r := s.c.call(s.c.api.columnType,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return Datatype(r[0])
}
@@ -336,11 +308,8 @@ func (s *Stmt) ColumnInt(col int) int {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
r, err := s.c.api.columnInteger.Call(s.c.ctx,
r := s.c.call(s.c.api.columnInteger,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return int64(r[0])
}
@@ -349,11 +318,8 @@ func (s *Stmt) ColumnInt64(col int) int64 {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
r, err := s.c.api.columnFloat.Call(s.c.ctx,
r := s.c.call(s.c.api.columnFloat,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return math.Float64frombits(r[0])
}
@@ -373,7 +339,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
case NULL:
return time.Time{}
default:
panic(assertErr())
panic(util.AssertErr())
}
t, err := format.Decode(v)
if err != nil {
@@ -387,30 +353,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r, err := s.c.api.columnText.Call(s.c.ctx,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
s.err = s.c.error(r[0])
return ""
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
mem := s.c.mem.view(ptr, uint32(r[0]))
return string(mem)
return string(s.ColumnRawText(col))
}
// ColumnBlob appends to buf and returns
@@ -419,28 +362,66 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
r, err := s.c.api.columnBlob.Call(s.c.ctx,
return append(buf, s.ColumnRawBlob(col)...)
}
// ColumnRawText returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
if err != nil {
panic(err)
}
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return buf[0:0]
return nil
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
return util.View(s.c.mod, ptr, r[0])
}
// ColumnRawBlob returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return nil
}
mem := s.c.mem.view(ptr, uint32(r[0]))
return append(buf[0:0], mem...)
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r[0])
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

127
tests/backup_test.go Normal file
View File

@@ -0,0 +1,127 @@
package tests
import (
"path/filepath"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBackup(t *testing.T) {
t.Parallel()
backupName := filepath.Join(t.TempDir(), "backup.db")
func() { // Create backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
err = db.Backup("main", backupName)
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Restore backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Restore("main", backupName)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Errors.
db, err := sqlite3.Open(backupName)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
}

297
tests/blob_test.go Normal file
View File

@@ -0,0 +1,297 @@
package tests
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBlob(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
size := blob.Size()
if size != 1024 {
t.Errorf("got %d, want 1024", size)
}
var data [1280]byte
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:size/2]))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:]))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
_, err = io.Copy(blob, bytes.NewReader(data[size/2:size]))
if err != nil {
t.Fatal(err)
}
_, err = blob.Seek(size/4, io.SeekStart)
if err != nil {
t.Fatal(err)
}
if got, err := io.ReadAll(blob); err != nil {
t.Fatal(err)
} else if !bytes.Equal(got, data[size/4:size]) {
t.Errorf("got %q, want %q", got, data[size/4:size])
}
if n, err := blob.Read(make([]byte, 1)); n != 0 || err != io.EOF {
t.Errorf("got (%d, %v), want (0, EOF)", n, err)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
_, err = db.OpenBlob("", "test", "col", db.LastInsertRowID(), false)
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
}
func TestBlob_Write_readonly(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
_, err = blob.Write([]byte("data"))
if !errors.Is(err, sqlite3.READONLY) {
t.Fatal("want error")
}
}
func TestBlob_Read_expired(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
err = db.Exec(`DELETE FROM test`)
if err != nil {
t.Fatal(err)
}
_, err = io.ReadAll(blob)
if !errors.Is(err, sqlite3.ABORT) {
t.Fatal("want error", err)
}
}
func TestBlob_Seek(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
_, err = blob.Seek(0, 10)
if err == nil {
t.Fatal("want error")
}
_, err = blob.Seek(-1, io.SeekCurrent)
if err == nil {
t.Fatal("want error")
}
n, err := blob.Seek(1, io.SeekEnd)
if err != nil {
t.Fatal(err)
}
if n != blob.Size()+1 {
t.Errorf("got %d, want %d", n, blob.Size())
}
_, err = blob.Write([]byte("data"))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
}
}
func TestBlob_Reopen(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
var rowids []int64
for i := 0; i < 100; i++ {
err = db.Exec(`INSERT INTO test VALUES (zeroblob(10))`)
if err != nil {
t.Fatal(err)
}
rowids = append(rowids, db.LastInsertRowID())
}
var blob *sqlite3.Blob
for i, rowid := range rowids {
if i > 0 {
err = blob.Reopen(rowid)
} else {
blob, err = db.OpenBlob("main", "test", "col", rowid, true)
}
if err != nil {
t.Fatal(err)
}
_, err = fmt.Fprintf(blob, "blob %d\n", i)
if err != nil {
t.Fatal(err)
}
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
for i, rowid := range rowids {
if i > 0 {
err = blob.Reopen(rowid)
} else {
blob, err = db.OpenBlob("main", "test", "col", rowid, false)
}
if err != nil {
t.Fatal(err)
}
var got int
_, err = fmt.Fscanf(blob, "blob %d\n", &got)
if err != nil {
t.Fatal(err)
}
if got != i {
t.Errorf("got %d, want %d", got, i)
}
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
}

View File

@@ -41,7 +41,9 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
}
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "foo.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "foo.db")+
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}
@@ -104,7 +106,7 @@ func testTxQuery(t params) {
}
defer tx.Rollback()
_, err = t.DB.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
_, err = tx.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
if err != nil {
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
}
@@ -124,7 +126,7 @@ func testTxQuery(t params) {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
t.Fatal("expected one row")
}
var name string

View File

@@ -13,19 +13,12 @@ import (
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
_, err := sqlite3.OpenFlags(".", 0)
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
}
if got := err.Error(); got != `sqlite3: unable to open database file` {
t.Error("got message: ", got)
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
@@ -53,23 +46,21 @@ func TestConn_Close_BUSY(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
var serr *sqlite3.Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.BUSY {
t.Errorf("got %d, want sqlite3.BUSY", rc)
if !errors.Is(err, sqlite3.BUSY) {
t.Errorf("got %v, want sqlite3.BUSY", err)
}
var terr interface{ Temporary() bool }
if !errors.As(err, &terr) || !terr.Temporary() {
t.Error("not temporary", err)
}
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -77,7 +68,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
db.SetInterrupt(ctx)
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
@@ -85,7 +76,7 @@ func TestConn_SetInterrupt(t *testing.T) {
t.Fatal(err)
}
db.SetInterrupt(nil)
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`
WITH RECURSIVE
@@ -103,40 +94,24 @@ func TestConn_SetInterrupt(t *testing.T) {
}
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
db.SetInterrupt(ctx.Done())
var serr *sqlite3.Error
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(nil)
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
db.SetInterrupt(ctx)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
@@ -207,7 +182,7 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
@@ -221,9 +196,9 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := serr.SQL(); got != `FRM sqlite_schema` {
t.Error("got SQL: ", got)
t.Error("got SQL:", got)
}
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
t.Error("got message: ", got)
t.Error("got message:", got)
}
}

View File

@@ -30,7 +30,7 @@ func testDB(t *testing.T, name string) {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}

View File

@@ -34,7 +34,7 @@ func TestDriver(t *testing.T) {
}
res, err := conn.ExecContext(ctx,
`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}

77
tests/ext_test.go Normal file
View File

@@ -0,0 +1,77 @@
package tests
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func Test_base64(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// base64
stmt, _, err := db.Prepare(`SELECT base64('TWFueSBoYW5kcyBtYWtlIGxpZ2h0IHdvcmsu')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if got := stmt.ColumnText(0); got != "Many hands make light work." {
t.Errorf("got %q", got)
}
}
func Test_decimal(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT decimal_add(decimal('0.1'), decimal('0.2')) = decimal('0.3')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}
func Test_uint(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT 'z2' < 'z11' COLLATE UINT`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}

View File

@@ -1,7 +1,6 @@
package tests
import (
"errors"
"io"
"os"
"os/exec"
@@ -15,14 +14,21 @@ import (
)
func TestParallel(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
testParallel(t, name, 100)
testParallel(t, name, iter)
testIntegrity(t, name)
}
func TestMultiProcess(t *testing.T) {
if testing.Short() {
t.Skip()
t.Skip("skipping in short mode")
}
name := filepath.Join(t.TempDir(), "test.db")
@@ -46,10 +52,6 @@ func TestMultiProcess(t *testing.T) {
testParallel(t, name, 1000)
if err := cmd.Wait(); err != nil {
t.Error(err)
var eerr *exec.ExitError
if errors.As(err, &eerr) {
t.Error(eerr.Stderr)
}
}
testIntegrity(t, name)
}
@@ -72,8 +74,10 @@ func testParallel(t *testing.T, name string, n int) {
defer db.Close()
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 10000;
PRAGMA busy_timeout=10000;
PRAGMA synchronous=off;
PRAGMA locking_mode=normal;
PRAGMA journal_mode=truncate;
`)
if err != nil {
return err
@@ -84,7 +88,7 @@ func testParallel(t *testing.T, name string, n int) {
return err
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
return err
}
@@ -100,8 +104,8 @@ func testParallel(t *testing.T, name string, n int) {
defer db.Close()
err = db.Exec(`
PRAGMA locking_mode = NORMAL;
PRAGMA busy_timeout = 10000;
PRAGMA busy_timeout=10000;
PRAGMA locking_mode=normal;
`)
if err != nil {
return err
@@ -138,7 +142,7 @@ func testParallel(t *testing.T, name string, n int) {
}
var group errgroup.Group
group.SetLimit(4)
group.SetLimit(6)
for i := 0; i < n; i++ {
if i&7 != 7 {
group.Go(reader)
@@ -148,7 +152,7 @@ func testParallel(t *testing.T, name string, n int) {
}
err = group.Wait()
if err != nil {
t.Fatal(err)
t.Error(err)
}
}

View File

@@ -22,7 +22,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO test(col) VALUES(?)`)
stmt, _, err := db.Prepare(`INSERT INTO test VALUES (?)`)
if err != nil {
t.Fatal(err)
}
@@ -400,7 +400,7 @@ func TestStmt_BindName(t *testing.T) {
}
}
func TestStmt_Time(t *testing.T) {
func TestStmt_ColumnTime(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
@@ -430,23 +430,23 @@ func TestStmt_Time(t *testing.T) {
}
if now := time.Now(); stmt.Step() {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !reference.Equal(got) {
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !reference.Equal(got) {
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); reference.Sub(got) > time.Millisecond {
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); got.Sub(reference).Abs() > time.Millisecond {
t.Errorf("got %v, want %v", got, reference)
}
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); now.Sub(got) > time.Second {
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
t.Errorf("got %v, want %v", got, now)
}
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); now.Sub(got) > time.Millisecond {
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second/10 {
t.Errorf("got %v, want %v", got, now)
}

169
tests/time_test.go Normal file
View File

@@ -0,0 +1,169 @@
package tests
import (
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
time time.Time
want any
}{
{sqlite3.TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{sqlite3.TimeFormatJulianDay, reference, 2456572.849526851851852},
{sqlite3.TimeFormatUnix, reference, int64(1381134199)},
{sqlite3.TimeFormatUnixFrac, reference, 1381134199.120},
{sqlite3.TimeFormatUnixMilli, reference, int64(1381134199_120)},
{sqlite3.TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{sqlite3.TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{sqlite3.TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
stmt.Step()
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
if err != nil {
t.Fatal(err)
}
var t0 time.Time
for stmt.Step() {
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
if t0.After(t1) {
t.Errorf("got %v after %v", t0, t1)
}
t0 = t1
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}

535
tests/tx_test.go Normal file
View File

@@ -0,0 +1,535 @@
package tests
import (
"context"
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestConn_Transaction_exec(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
errFailed := errors.New("failed")
count := func() int {
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
t.Fatal(stmt.Err())
return 0
}
insert := func(succeed bool) (err error) {
tx := db.Begin()
defer tx.End(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errFailed
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
err = insert(false)
if err != errFailed {
t.Errorf("got %v, want errFailed", err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
}
func TestConn_Transaction_panic(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES ('one');`)
if err != nil {
t.Fatal(err)
}
panics := func() (err error) {
tx := db.Begin()
defer tx.End(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
return err
}
panic("omg!")
}
defer func() {
p := recover()
if p != "omg!" {
t.Errorf("got %v, want panic", p)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
return
}
t.Fatal(stmt.Err())
}()
err = panics()
if err != nil {
t.Error(err)
}
}
func TestConn_Transaction_interrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
tx, err := db.BeginImmediate()
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
tx.End(&err)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
tx, err = db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
cancel()
_, err = db.BeginImmediate()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = db.Exec(`INSERT INTO test VALUES (3)`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Transaction_interrupted(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
cancel()
tx := db.Begin()
err = tx.Commit()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
tx := db.Begin()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`COMMIT`)
if err != nil {
t.Fatal(err)
}
tx.End(&err)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_exec(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
errFailed := errors.New("failed")
count := func() int {
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
t.Fatal(stmt.Err())
return 0
}
insert := func(succeed bool) (err error) {
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errFailed
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = insert(true)
if err != nil {
t.Fatal(err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
err = insert(false)
if err != errFailed {
t.Errorf("got %v, want errFailed", err)
}
if got := count(); got != 2 {
t.Errorf("got %d, want 2", got)
}
}
func TestConn_Savepoint_panic(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES ('one');`)
if err != nil {
t.Fatal(err)
}
panics := func() (err error) {
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
return err
}
panic("omg!")
}
defer func() {
p := recover()
if p != "omg!" {
t.Errorf("got %v, want panic", p)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
return
}
t.Fatal(stmt.Err())
}()
err = panics()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_interrupt(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
savept1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
savept2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint().Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = db.Exec(`INSERT INTO test VALUES (4)`)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = context.Canceled
savept2.Release(&err)
if err != context.Canceled {
t.Fatal(err)
}
err = nil
savept1.Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}
func TestConn_Savepoint_rollback(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`COMMIT`)
if err != nil {
t.Fatal(err)
}
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
t.Errorf("got %d, want 1", got)
}
}
err = stmt.Err()
if err != nil {
t.Error(err)
}
}

View File

@@ -1,19 +1,23 @@
package sqlite3
package tests
import "testing"
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
data sqlite3.Datatype
want string
}{
{INTEGER, "INTEGER"},
{FLOAT, "FLOAT"},
{TEXT, "TEXT"},
{BLOB, "BLOB"},
{NULL, "NULL"},
{sqlite3.INTEGER, "INTEGER"},
{sqlite3.FLOAT, "FLOAT"},
{sqlite3.TEXT, "TEXT"},
{sqlite3.BLOB, "BLOB"},
{sqlite3.NULL, "NULL"},
{10, "10"},
}
for _, tt := range tests {

107
time.go
View File

@@ -6,11 +6,15 @@ import (
"strings"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
// TimeFormat specifies how to encode/decode time values.
//
// See the documentation for the [TimeFormatDefault] constant
// for formats recognized by SQLite.
//
// https://www.sqlite.org/lang_datefunc.html
type TimeFormat string
@@ -57,12 +61,31 @@ const (
// Encode encodes a time value using this format.
//
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving timezone.
// with nanosecond accuracy, and preserving any timezone offset.
//
// Formats that don't record the timezone
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
//
// Returns a string for the text formats,
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
@@ -81,11 +104,13 @@ func (f TimeFormat) Encode(t time.Time) any {
// Special formats
case TimeFormatDefault, TimeFormatAuto:
f = time.RFC3339Nano
}
// SQLite assumes UTC if unspecified.
if !strings.Contains(string(f), "MST") &&
!strings.Contains(string(f), "Z07") &&
!strings.Contains(string(f), "-07") {
case
TimeFormat1, TimeFormat2,
TimeFormat3, TimeFormat4,
TimeFormat5, TimeFormat6,
TimeFormat7, TimeFormat8,
TimeFormat9, TimeFormat10:
t = t.UTC()
}
return t.Format(string(f))
@@ -93,8 +118,23 @@ func (f TimeFormat) Encode(t time.Time) any {
// Decode decodes a time value using this format.
//
// Decoding of SQLite recognized formats is lenient:
// timezones and fractional seconds are always optional.
// The time value can be a string, an int64, or a float64.
//
// Formats [TimeFormat8] through [TimeFormat10]
// (and [TimeFormat8TZ] through [TimeFormat10TZ])
// assume a date of 2000-01-01.
//
// The timezone indicator and fractional seconds are always optional
// for formats [TimeFormat2] through [TimeFormat10]
// (and [TimeFormat2TZ] through [TimeFormat10TZ]).
//
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) {
@@ -109,7 +149,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
@@ -128,7 +168,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMilli:
@@ -145,7 +185,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMilli(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMicro:
@@ -162,14 +202,14 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMicro(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
v = i
}
@@ -179,7 +219,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(0, int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
// Special formats
@@ -249,7 +289,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case
@@ -261,16 +301,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
f := string(f)
f = strings.TrimSuffix(f, "Z07:00")
f = strings.TrimSuffix(f, ".000")
t, err := time.Parse(f+"Z07:00", s)
if err != nil {
t, err = time.Parse(f, s)
}
return t, err
return f.parseRelaxed(s)
case
TimeFormat8, TimeFormat8TZ,
@@ -278,26 +311,30 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
}
f := string(f)
f = strings.TrimSuffix(f, "Z07:00")
f = strings.TrimSuffix(f, ".000")
t, err := time.Parse(f+"Z07:00", s)
if err != nil {
t, err = time.Parse(f, s)
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
default:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
f := string(f)
if f == "" {
f = time.RFC3339Nano
}
return time.Parse(f, s)
return time.Parse(string(f), s)
}
}
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
fs := string(f)
fs = strings.TrimSuffix(fs, "Z07:00")
fs = strings.TrimSuffix(fs, ".000")
t, err := time.Parse(fs+"Z07:00", s)
if err != nil {
return time.Parse(fs, s)
}
return t, nil
}

View File

@@ -1,118 +0,0 @@
package sqlite3
import (
"reflect"
"testing"
"time"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
time time.Time
want any
}{
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{TimeFormatJulianDay, reference, 2456572.849526851851852},
{TimeFormatUnix, reference, int64(1381134199)},
{TimeFormatUnixFrac, reference, 1381134199.120},
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{TimeFormatJulianDay, false, time.Time{}, 0, true},
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{TimeFormatUnix, "abc", time.Time{}, 0, true},
{TimeFormatUnix, false, time.Time{}, 0, true},
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{TimeFormatUnixNano, false, time.Time{}, 0, true},
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199", reference, time.Second, false},
{TimeFormatAuto, "1381134199120", reference, 0, false},
{TimeFormatAuto, "1381134199120000", reference, 0, false},
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormatAuto, "abc", time.Time{}, 0, true},
{TimeFormatAuto, false, time.Time{}, 0, true},
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormat9, "08:23:19.12", reftime, 0, false},
{TimeFormat3, false, time.Time{}, 0, true},
{TimeFormat9, false, time.Time{}, 0, true},
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}

202
tx.go Normal file
View File

@@ -0,0 +1,202 @@
package sqlite3
import (
"context"
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
)
// Tx is an in-progress database transaction.
//
// https://www.sqlite.org/lang_transaction.html
type Tx struct {
c *Conn
}
// Begin starts a deferred transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
// BEGIN even if interrupted.
err := c.txExecInterrupted(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Tx{c}
}
// BeginImmediate starts an immediate transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) BeginImmediate() (Tx, error) {
err := c.Exec(`BEGIN IMMEDIATE`)
if err != nil {
return Tx{}, err
}
return Tx{c}, nil
}
// BeginExclusive starts an exclusive transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) BeginExclusive() (Tx, error) {
err := c.Exec(`BEGIN EXCLUSIVE`)
if err != nil {
return Tx{}, err
}
return Tx{c}, nil
}
// End calls either [Tx.Commit] or [Tx.Rollback]
// depending on whether *error points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// tx := conn.Begin()
// defer tx.End(&err)
//
// // ... do work in the transaction
// }
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) End(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if (errp == nil || *errp == nil) && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
}
}
// Commit commits the transaction.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
return tx.c.txExecInterrupted(`ROLLBACK`)
}
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://www.sqlite.org/lang_savepoint.html
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
panic(err)
}
return Savepoint{c: c, name: name}
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// savept := conn.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if (errp == nil || *errp == nil) && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txExecInterrupted(fmt.Sprintf(`
ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://www.sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
}
func (c *Conn) txExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
}

16
util.go
View File

@@ -1,16 +0,0 @@
package sqlite3
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

321
vfs.go
View File

@@ -1,321 +0,0 @@
package sqlite3
import (
"context"
"crypto/rand"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/sys"
)
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
_, err := wasi.Instantiate(ctx)
if err != nil {
panic(err)
}
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("go_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("go_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("go_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("go_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("go_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("go_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("go_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("go_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("go_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("go_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("go_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("go_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("go_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("go_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("go_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("go_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("go_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("go_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("go_file_control")
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
_ = mod.CloseWithExitCode(ctx, exitCode)
// Prevent any code from executing after this function.
panic(sys.NewExitError(mod.Name(), exitCode))
}
func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uint32 {
tm := time.Unix(int64(t), 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
mem := memory{mod}
mem.writeUint32(pTm+0*ptrlen, uint32(tm.Second()))
mem.writeUint32(pTm+1*ptrlen, uint32(tm.Minute()))
mem.writeUint32(pTm+2*ptrlen, uint32(tm.Hour()))
mem.writeUint32(pTm+3*ptrlen, uint32(tm.Day()))
mem.writeUint32(pTm+4*ptrlen, uint32(tm.Month()-time.January))
mem.writeUint32(pTm+5*ptrlen, uint32(tm.Year()-1900))
mem.writeUint32(pTm+6*ptrlen, uint32(tm.Weekday()-time.Sunday))
mem.writeUint32(pTm+7*ptrlen, uint32(tm.YearDay()-1))
mem.writeUint32(pTm+8*ptrlen, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.view(zByte, nByte)
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, pVfs, nMicro uint32) uint32 {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uint32 {
day := julianday.Float(time.Now())
memory{mod}.writeFloat64(prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) uint32 {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
memory{mod}.writeUint64(piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) uint32 {
rel := memory{mod}.readString(zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return uint32(IOERR)
}
// Consider either using [filepath.EvalSymlinks] to canonicalize the path (as the Unix VFS does).
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
size := uint32(len(abs) + 1)
if size > nFull {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
return _OK
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) uint32 {
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _OK
}
if err != nil {
return uint32(IOERR_DELETE)
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err == nil {
err = f.Sync()
f.Close()
}
if err != nil {
return uint32(IOERR_DELETE)
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
fi, err := os.Stat(path)
var res uint32
switch {
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
return uint32(IOERR_ACCESS)
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
res = 1
} else {
res = 0
}
case errors.Is(err, fs.ErrPermission):
res = 0
default:
return uint32(IOERR_ACCESS)
}
memory{mod}.writeUint32(pResOut, res)
return _OK
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var file *os.File
if zName == 0 {
file, err = os.CreateTemp("", "*.db")
} else {
name := memory{mod}.readString(zName, _MAX_PATHNAME)
file, err = os.OpenFile(name, oflags, 0600)
}
if err != nil {
return uint32(CANTOPEN)
}
if flags&OPEN_DELETEONCLOSE != 0 {
deleteOnClose(file)
}
info, err := file.Stat()
if err != nil {
return uint32(CANTOPEN)
}
if info.IsDir() {
return uint32(CANTOPEN_ISDIR)
}
id := vfsGetOpenFileID(file, info)
vfsFilePtr{mod, pFile}.SetID(id).SetLock(_NO_LOCK)
if pOutFlags != 0 {
memory{mod}.writeUint32(pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
id := vfsFilePtr{mod, pFile}.ID()
err := vfsReleaseOpenFile(id)
if err != nil {
return uint32(IOERR_CLOSE)
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, iAmt)
file := vfsFilePtr{mod, pFile}.OSFile()
n, err := file.ReadAt(buf, int64(iOfst))
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return uint32(IOERR_READ)
}
for i := range buf[n:] {
buf[n+i] = 0
}
return uint32(IOERR_SHORT_READ)
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, iAmt)
file := vfsFilePtr{mod, pFile}.OSFile()
_, err := file.WriteAt(buf, int64(iOfst))
if err != nil {
return uint32(IOERR_WRITE)
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
err := file.Truncate(int64(nByte))
if err != nil {
return uint32(IOERR_TRUNCATE)
}
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
err := file.Sync()
if err != nil {
return uint32(IOERR_FSYNC)
}
return _OK
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 {
// This uses [os.File.Seek] because we don't care about the offset for reading/writing.
// But consider using [os.File.Stat] instead (as other VFSes do).
file := vfsFilePtr{mod, pFile}.OSFile()
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return uint32(IOERR_SEEK)
}
memory{mod}.writeUint64(pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
return uint32(NOTFOUND)
}

View File

@@ -1,107 +0,0 @@
package sqlite3
import (
"os"
"sync"
"github.com/tetratelabs/wazero/api"
)
type vfsOpenFile struct {
file *os.File
info os.FileInfo
nref int
locker vfsFileLocker
}
var (
vfsOpenFiles []*vfsOpenFile
vfsOpenFilesMtx sync.Mutex
)
func vfsGetOpenFileID(file *os.File, info os.FileInfo) uint32 {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
// Reuse an already opened file.
for id, of := range vfsOpenFiles {
if of == nil {
continue
}
if os.SameFile(info, of.info) {
of.nref++
_ = file.Close()
return uint32(id)
}
}
of := &vfsOpenFile{
file: file,
info: info,
nref: 1,
locker: vfsFileLocker{file: file},
}
// Find an empty slot.
for id, ptr := range vfsOpenFiles {
if ptr == nil {
vfsOpenFiles[id] = of
return uint32(id)
}
}
// Add a new slot.
id := len(vfsOpenFiles)
vfsOpenFiles = append(vfsOpenFiles, of)
return uint32(id)
}
func vfsReleaseOpenFile(id uint32) error {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
of := vfsOpenFiles[id]
if of.nref--; of.nref > 0 {
return nil
}
err := of.file.Close()
vfsOpenFiles[id] = nil
return err
}
type vfsFilePtr struct {
api.Module
ptr uint32
}
func (p vfsFilePtr) OSFile() *os.File {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return vfsOpenFiles[id].file
}
func (p vfsFilePtr) Locker() *vfsFileLocker {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return &vfsOpenFiles[id].locker
}
func (p vfsFilePtr) ID() uint32 {
return memory{p}.readUint32(p.ptr + ptrlen)
}
func (p vfsFilePtr) Lock() vfsLockState {
return vfsLockState(memory{p}.readUint32(p.ptr + 2*ptrlen))
}
func (p vfsFilePtr) SetID(id uint32) vfsFilePtr {
memory{p}.writeUint32(p.ptr+ptrlen, id)
return p
}
func (p vfsFilePtr) SetLock(lock vfsLockState) vfsFilePtr {
memory{p}.writeUint32(p.ptr+2*ptrlen, uint32(lock))
return p
}

View File

@@ -1,276 +0,0 @@
package sqlite3
import (
"context"
"os"
"sync"
"github.com/tetratelabs/wazero/api"
)
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
_NO_LOCK = 0
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
_SHARED_LOCK = 1
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
_RESERVED_LOCK = 2
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
_PENDING_LOCK = 3
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
_EXCLUSIVE_LOCK = 4
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
type vfsLockState uint32
type vfsFileLocker struct {
sync.Mutex
file *os.File
state vfsLockState
shared int
}
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check. SQLite never explicitly requests a pendig lock.
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
// Connection state check.
panic(assertErr())
case cLock == _NO_LOCK && eLock > _SHARED_LOCK:
// We never move from unlocked to anything higher than a shared lock.
panic(assertErr())
case cLock != _SHARED_LOCK && eLock == _RESERVED_LOCK:
// A shared lock is always held when a reserved lock is requested.
panic(assertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
// File state check.
switch {
case fLock.state < _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
panic(assertErr())
case fLock.state == _NO_LOCK && fLock.shared != 0:
panic(assertErr())
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
panic(assertErr())
case fLock.state != _NO_LOCK && fLock.shared <= 0:
panic(assertErr())
case fLock.state < cLock:
panic(assertErr())
}
// If some other connection has a lock that precludes the requested lock, return BUSY.
if cLock != fLock.state && (eLock > _SHARED_LOCK || fLock.state >= _PENDING_LOCK) {
return uint32(BUSY)
}
switch eLock {
case _SHARED_LOCK:
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := fLock.CheckPending(); locked {
return uint32(BUSY)
}
// If some other connection has a SHARED or RESERVED lock,
// increment the reference count and return OK.
if fLock.state == _SHARED_LOCK || fLock.state == _RESERVED_LOCK {
ptr.SetLock(_SHARED_LOCK)
fLock.shared++
return _OK
}
// Must be unlocked to get SHARED.
if fLock.state != _NO_LOCK {
panic(assertErr())
}
if rc := fLock.GetShared(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
fLock.state = _SHARED_LOCK
fLock.shared = 1
return _OK
case _RESERVED_LOCK:
// Must be SHARED to get RESERVED.
if fLock.state != _SHARED_LOCK {
panic(assertErr())
}
if rc := fLock.GetReserved(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_RESERVED_LOCK)
fLock.state = _RESERVED_LOCK
return _OK
case _EXCLUSIVE_LOCK:
// Must be SHARED, PENDING or RESERVED to get EXCLUSIVE.
if fLock.state <= _NO_LOCK || fLock.state >= _EXCLUSIVE_LOCK {
panic(assertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if fLock.state == _RESERVED_LOCK {
if rc := fLock.GetPending(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_PENDING_LOCK)
fLock.state = _PENDING_LOCK
}
// We are trying for an EXCLUSIVE lock but another connection is still holding a shared lock.
if fLock.shared > 1 {
return uint32(BUSY)
}
if rc := fLock.GetExclusive(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_EXCLUSIVE_LOCK)
fLock.state = _EXCLUSIVE_LOCK
return _OK
default:
panic(assertErr())
}
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check.
if eLock != _NO_LOCK && eLock != _SHARED_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
panic(assertErr())
}
// If we don't have a more restrictive lock, do nothing.
if cLock <= eLock {
return _OK
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
// File state check.
switch {
case fLock.state <= _NO_LOCK || fLock.state > _EXCLUSIVE_LOCK:
panic(assertErr())
case fLock.state == _EXCLUSIVE_LOCK && fLock.shared != 1:
panic(assertErr())
case fLock.shared <= 0:
panic(assertErr())
case fLock.state < cLock:
panic(assertErr())
}
if cLock > _SHARED_LOCK {
// The connection must own the lock to release it.
if cLock != fLock.state {
panic(assertErr())
}
if eLock == _SHARED_LOCK {
if rc := fLock.Downgrade(); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
fLock.state = _SHARED_LOCK
return _OK
}
}
// If we get here, make sure we're dropping all locks.
if eLock != _NO_LOCK {
panic(assertErr())
}
// Release the connection lock and decrement the shared lock counter.
// Release the file lock only when all connections have released the lock.
ptr.SetLock(_NO_LOCK)
if fLock.shared--; fLock.shared == 0 {
rc := fLock.Release()
fLock.state = _NO_LOCK
return uint32(rc)
}
return _OK
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
if cLock > _SHARED_LOCK {
panic(assertErr())
}
fLock := ptr.Locker()
fLock.Lock()
defer fLock.Unlock()
if fLock.state >= _RESERVED_LOCK {
memory{mod}.writeUint32(pResOut, 1)
return _OK
}
locked, rc := fLock.CheckReserved()
var res uint32
if locked {
res = 1
}
memory{mod}.writeUint32(pResOut, res)
return uint32(rc)
}

View File

@@ -1,122 +0,0 @@
package sqlite3
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
)
func Test_vfsLock(t *testing.T) {
// Other OSes lack open file descriptors locks.
switch runtime.GOOS {
case "linux", "darwin", "illumos", "windows":
//
default:
t.Skip()
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
// Bypass open file reuse.
vfsOpenFiles = append(vfsOpenFiles, &vfsOpenFile{
file: file1,
nref: 1,
locker: vfsFileLocker{file: file1},
}, &vfsOpenFile{
file: file2,
nref: 1,
locker: vfsFileLocker{file: file2},
})
mem := newMemory(128)
mem.writeUint32(4+4, 0)
mem.writeUint32(16+4, 1)
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
rc := vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _RESERVED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _EXCLUSIVE_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsUnlock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -1,271 +0,0 @@
package sqlite3
import (
"bytes"
"context"
"errors"
"io/fs"
"os"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/ncruces/julianday"
)
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1)
t.Error("want panic")
}
func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128)
rc := vfsLocaltime(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
epoch := time.Unix(0, 0)
if s := mem.readUint32(4 + 0*4); int(s) != epoch.Second() {
t.Error("wrong second")
}
if m := mem.readUint32(4 + 1*4); int(m) != epoch.Minute() {
t.Error("wrong minute")
}
if h := mem.readUint32(4 + 2*4); int(h) != epoch.Hour() {
t.Error("wrong hour")
}
if d := mem.readUint32(4 + 3*4); int(d) != epoch.Day() {
t.Error("wrong day")
}
if m := mem.readUint32(4 + 4*4); time.Month(1+m) != epoch.Month() {
t.Error("wrong month")
}
if y := mem.readUint32(4 + 5*4); 1900+int(y) != epoch.Year() {
t.Error("wrong year")
}
if w := mem.readUint32(4 + 6*4); time.Weekday(w) != epoch.Weekday() {
t.Error("wrong weekday")
}
if d := mem.readUint32(4 + 7*4); int(d) != epoch.YearDay()-1 {
t.Error("wrong yearday")
}
}
func Test_vfsRandomness(t *testing.T) {
mem := newMemory(128)
rc := vfsRandomness(context.TODO(), mem.mod, 0, 16, 4)
if rc != 16 {
t.Fatal("returned", rc)
}
var zero [16]byte
if got := mem.view(4, 16); bytes.Equal(got, zero[:]) {
t.Fatal("all zero")
}
}
func Test_vfsSleep(t *testing.T) {
start := time.Now()
rc := vfsSleep(context.TODO(), 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
want := 123456 * time.Microsecond
if got := time.Since(start); got < want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime(t *testing.T) {
mem := newMemory(128)
now := time.Now()
rc := vfsCurrentTime(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
want := julianday.Float(now)
if got := mem.readFloat64(4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime64(t *testing.T) {
mem := newMemory(128)
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000
if got := mem.readUint64(4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsFullPathname(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, ".")
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
want, _ := filepath.Abs(".")
if got := mem.readString(8, _MAX_PATHNAME); got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsDelete(t *testing.T) {
name := filepath.Join(t.TempDir(), "test.db")
file, err := os.Create(name)
if err != nil {
t.Fatal(err)
}
file.Close()
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, name)
rc := vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}
func Test_vfsAccess(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
f.Close()
}
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
t.Fatal(err)
}
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
t.Error("can't access directory")
}
mem.writeString(8, file)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {
mem := newMemory(128)
// Open a temporary file.
rc := vfsOpen(context.TODO(), mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Write stuff.
text := "Hello world!"
mem.writeString(16, text)
rc = vfsWrite(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != uint32(len(text)) {
t.Errorf("got %d", got)
}
// Partial read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 4)
if rc != uint32(IOERR_SHORT_READ) {
t.Fatal("returned", rc)
}
if got := mem.readString(16, 64); got != text[4:] {
t.Errorf("got %q", got)
}
// Truncate the file.
rc = vfsTruncate(context.TODO(), mem.mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != 4 {
t.Errorf("got %d", got)
}
// Read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readString(32, 64); got != text[:4] {
t.Errorf("got %q", got)
}
// Close the file.
rc = vfsClose(context.TODO(), mem.mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -1,158 +0,0 @@
//go:build unix
package sqlite3
import (
"os"
"runtime"
"syscall"
)
func deleteOnClose(f *os.File) {
_ = os.Remove(f.Name())
}
func (l *vfsFileLocker) GetShared() xErrorCode {
// Acquire the SHARED lock.
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
// Acquire the RESERVED lock.
return l.writeLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) GetPending() xErrorCode {
// Acquire the PENDING lock.
return l.writeLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) GetExclusive() xErrorCode {
// Acquire the EXCLUSIVE lock.
return l.writeLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
if l.state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return l.unlock(_PENDING_BYTE, 2)
}
func (l *vfsFileLocker) Release() xErrorCode {
// Release all locks.
return l.unlock(0, 0)
}
func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
// Test the RESERVED lock.
return l.checkLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
// Test the PENDING lock.
return l.checkLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) unlock(start, len int64) xErrorCode {
err := l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (l *vfsFileLocker) readLock(start, len int64) xErrorCode {
return l.errorCode(l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}), IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len int64) xErrorCode {
return l.errorCode(l.fcntlSetLock(&syscall.Flock_t{
Type: syscall.F_WRLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}
func (l *vfsFileLocker) checkLock(start, len int64) (bool, xErrorCode) {
lock := syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}
if l.fcntlGetLock(&lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != syscall.F_UNLCK, _OK
}
func (l *vfsFileLocker) fcntlGetLock(lock *syscall.Flock_t) error {
F_GETLK := syscall.F_GETLK
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_GETLK = 36 // F_OFD_GETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_GETLK = 92 // F_OFD_GETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_GETLK = 47 // F_OFD_GETLK
}
return syscall.FcntlFlock(l.file.Fd(), F_GETLK, lock)
}
func (l *vfsFileLocker) fcntlSetLock(lock *syscall.Flock_t) error {
F_SETLK := syscall.F_SETLK
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_SETLK = 37 // F_OFD_SETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_SETLK = 90 // F_OFD_SETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_SETLK = 48 // F_OFD_SETLK
}
return syscall.FcntlFlock(l.file.Fd(), F_SETLK, lock)
}
func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
return xErrorCode(PERM)
}
}
return def
}

View File

@@ -1,127 +0,0 @@
package sqlite3
import (
"os"
"syscall"
"golang.org/x/sys/windows"
)
func deleteOnClose(f *os.File) {}
func (l *vfsFileLocker) GetShared() xErrorCode {
// Acquire the SHARED lock.
return l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
func (l *vfsFileLocker) GetReserved() xErrorCode {
// Acquire the RESERVED lock.
return l.writeLock(_RESERVED_BYTE, 1)
}
func (l *vfsFileLocker) GetPending() xErrorCode {
// Acquire the PENDING lock.
return l.writeLock(_PENDING_BYTE, 1)
}
func (l *vfsFileLocker) GetExclusive() xErrorCode {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := l.writeLock(_SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc != _OK {
l.readLock(_SHARED_FIRST, _SHARED_SIZE)
}
return rc
}
func (l *vfsFileLocker) Downgrade() xErrorCode {
if l.state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := l.readLock(_SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if l.state >= _RESERVED_LOCK {
l.unlock(_RESERVED_BYTE, 1)
}
if l.state >= _PENDING_LOCK {
l.unlock(_PENDING_BYTE, 1)
}
return _OK
}
func (l *vfsFileLocker) Release() xErrorCode {
// Release all locks.
if l.state >= _RESERVED_LOCK {
l.unlock(_RESERVED_BYTE, 1)
}
if l.state >= _SHARED_LOCK {
l.unlock(_SHARED_FIRST, _SHARED_SIZE)
}
if l.state >= _PENDING_LOCK {
l.unlock(_PENDING_BYTE, 1)
}
return _OK
}
func (l *vfsFileLocker) CheckReserved() (bool, xErrorCode) {
// Test the RESERVED lock.
rc := l.readLock(_RESERVED_BYTE, 1)
if rc == _OK {
l.unlock(_RESERVED_BYTE, 1)
}
return rc != _OK, _OK
}
func (l *vfsFileLocker) CheckPending() (bool, xErrorCode) {
// Test the PENDING lock.
rc := l.readLock(_PENDING_BYTE, 1)
if rc == _OK {
l.unlock(_PENDING_BYTE, 1)
}
return rc != _OK, _OK
}
func (l *vfsFileLocker) unlock(start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(l.file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (l *vfsFileLocker) readLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_RDLOCK)
}
func (l *vfsFileLocker) writeLock(start, len uint32) xErrorCode {
return l.errorCode(windows.LockFileEx(windows.Handle(l.file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
}
func (*vfsFileLocker) errorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, _ := err.(syscall.Errno); errno == windows.ERROR_INVALID_HANDLE {
return def
}
return xErrorCode(BUSY)
}