Compare commits

...

93 Commits

Author SHA1 Message Date
Nuno Cruces
dd4823ebf0 Documentation, tests. 2023-05-19 14:45:40 +01:00
Nuno Cruces
663b23ff3b Documentation. 2023-05-19 13:47:37 +01:00
Nuno Cruces
4e2ce6c635 Refactor VFS. 2023-05-19 03:04:07 +01:00
Nuno Cruces
66effb4249 Rename. 2023-05-19 02:28:30 +01:00
Nuno Cruces
e1cce83f71 More VFS API. 2023-05-19 02:00:16 +01:00
Nuno Cruces
df953b31c2 Refactor VFS. 2023-05-18 16:00:34 +01:00
Nuno Cruces
67cc3d35d5 More VFS API. 2023-05-18 01:34:54 +01:00
Nuno Cruces
6846b72b31 Add SetInterrupt to DriverConn. 2023-05-17 14:38:47 +01:00
Nuno Cruces
c94cdaf720 More VFS API. 2023-05-17 14:38:47 +01:00
Nuno Cruces
f6a887dd1c Allow manual runs. 2023-05-17 14:38:35 +01:00
Nuno Cruces
2a010a2022 Towards VFS API. 2023-05-17 01:00:08 +01:00
Nuno Cruces
c86b06b048 Refactor. 2023-05-16 17:52:37 +01:00
Nuno Cruces
a44a13a506 Rename. 2023-05-16 15:40:08 +01:00
Nuno Cruces
4604719966 SQLite 3.42.0. 2023-05-16 14:56:47 +01:00
Nuno Cruces
03168d5d34 Build scripts. 2023-05-16 12:14:34 +01:00
Nuno Cruces
be4b6304f9 Documentation. 2023-05-16 12:14:23 +01:00
Nuno Cruces
b5e678a40a Inline host calls. 2023-05-09 14:41:24 +01:00
Nuno Cruces
2fc4698ddc Avoid some allocs. 2023-05-08 10:47:19 +01:00
dependabot[bot]
bd86539577 Bump golang.org/x/sync from 0.1.0 to 0.2.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.1.0 to 0.2.0.
- [Commits](https://github.com/golang/sync/compare/v0.1.0...v0.2.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-05 07:52:12 +01:00
dependabot[bot]
7a785d9aec Bump golang.org/x/sys from 0.7.0 to 0.8.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.7.0 to 0.8.0.
- [Commits](https://github.com/golang/sys/compare/v0.7.0...v0.8.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-05 07:51:01 +01:00
Nuno Cruces
59f79e8e74 Optimize calls. 2023-05-02 01:08:04 +01:00
dependabot[bot]
40457721d7 Bump github.com/tetratelabs/wazero from 1.0.3 to 1.1.0 (#11)
Bumps [github.com/tetratelabs/wazero](https://github.com/tetratelabs/wazero) from 1.0.3 to 1.1.0.
- [Release notes](https://github.com/tetratelabs/wazero/releases)
- [Commits](https://github.com/tetratelabs/wazero/compare/v1.0.3...v1.1.0)

---
updated-dependencies:
- dependency-name: github.com/tetratelabs/wazero
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 01:03:42 +01:00
Nuno Cruces
18eeb85783 Improve mock. 2023-04-28 13:50:50 +01:00
Nuno Cruces
b36536979b Fix reopen. 2023-04-28 13:50:32 +01:00
Nuno Cruces
a6226c3b31 wazero v1.0.3. 2023-04-22 10:06:48 +01:00
Nuno Cruces
bed2ee7674 Refactor. 2023-04-22 00:15:44 +01:00
Nuno Cruces
7e6d178122 Fix. 2023-04-21 13:33:24 +01:00
Nuno Cruces
f360c77a78 Optimize blobs. (#10) 2023-04-21 13:31:45 +01:00
Nuno Cruces
759b11a05d wazero 1.0.2. 2023-04-18 23:33:56 +01:00
Nuno Cruces
93ce586139 Optimize time. 2023-04-18 01:00:59 +01:00
Nuno Cruces
2e5082c616 Query pragmas at startup. 2023-04-17 00:29:20 +01:00
Nuno Cruces
34acc28af8 Fix CI. 2023-04-14 15:48:20 +01:00
Nuno Cruces
c1a640f7d8 Build using wasi-sdk. 2023-04-14 15:31:17 +01:00
Nuno Cruces
005b15610a Memory optimizations. 2023-04-11 15:33:38 +01:00
Nuno Cruces
23ee4ccb0b Refactor. 2023-04-10 19:55:44 +01:00
Nuno Cruces
3a8cfd036d Dependencies. 2023-04-10 14:24:06 +01:00
Nuno Cruces
c38382fd8e Refactor. 2023-03-31 14:33:24 +01:00
Nuno Cruces
8509e0b6c8 Test coverage. 2023-03-31 13:42:31 +01:00
Nuno Cruces
9c07e57252 Refactor. 2023-03-29 15:06:22 +01:00
Nuno Cruces
80039385d3 Read only files. 2023-03-25 11:46:13 +00:00
Nuno Cruces
89f4327b2b Sync journal directories. 2023-03-25 11:16:51 +00:00
Nuno Cruces
37a3ff37e8 wazero 1.0. 2023-03-24 21:17:30 +00:00
Nuno Cruces
d880d6842c Refactor VFS. 2023-03-23 13:29:26 +00:00
Nuno Cruces
bef46e7954 Locking improvements (windows). 2023-03-23 12:40:55 +00:00
Nuno Cruces
4e72b4d117 Locking fix. 2023-03-23 11:26:19 +00:00
Nuno Cruces
3b08d02a83 Lock refactoring. 2023-03-23 01:55:54 +00:00
Nuno Cruces
b19c12c4c7 SQLite 3.41.2, prefer speed over size. 2023-03-23 00:44:43 +00:00
Nuno Cruces
859a21ef4e CI improvements. 2023-03-22 12:08:33 +00:00
Nuno Cruces
8ff0ee752f Use flock. 2023-03-22 03:15:54 +00:00
Nuno Cruces
589ad86f76 Extensions. 2023-03-21 00:13:12 +00:00
Nuno Cruces
1a3a1be1f6 Fix test. 2023-03-20 14:26:25 +00:00
Nuno Cruces
222c217bc8 Scripts. 2023-03-20 13:06:31 +00:00
Nuno Cruces
c1dc716391 VFS performance. 2023-03-20 11:02:34 +00:00
Nuno Cruces
71e1e5a8ee Avoid some copies. 2023-03-20 02:16:42 +00:00
Nuno Cruces
e4efb20c71 Generate coverage chart. 2023-03-18 03:51:05 +00:00
Nuno Cruces
2c9459d907 Add SQLite speedtest1. 2023-03-18 03:03:11 +00:00
Nuno Cruces
d0875e5fab Lock timeouts. 2023-03-18 01:13:31 +00:00
Nuno Cruces
15dec13f15 FCNTL_SIZE_HINT, refactor. 2023-03-17 17:13:03 +00:00
Nuno Cruces
f38e36109a FCNTL_HAS_MOVED. 2023-03-17 14:11:09 +00:00
Nuno Cruces
4cb65ccbd9 xFileControl, xDeviceCharacteristics, PSOW. 2023-03-17 13:39:19 +00:00
Nuno Cruces
f789c2fb8b OPEN_NOFOLLOW. 2023-03-16 12:27:44 +00:00
Nuno Cruces
c6a2617dfc Locking fixes. 2023-03-16 02:52:22 +00:00
Nuno Cruces
6fc0afcd12 Towards lock timeouts. 2023-03-15 13:58:16 +00:00
Nuno Cruces
77088962f5 SQLite 3.41.1. 2023-03-15 13:29:09 +00:00
Nuno Cruces
71da34861b Fix time collation. 2023-03-13 04:19:58 +00:00
Nuno Cruces
56e8281bdb Time collation tests. 2023-03-10 16:42:20 +00:00
Nuno Cruces
f61d430e65 Documentation. 2023-03-10 16:26:19 +00:00
Nuno Cruces
dbaed53b9a Sync and delete improvements. 2023-03-10 14:17:02 +00:00
Nuno Cruces
8b1bfd04e3 Simplify windows hacks. 2023-03-10 10:43:02 +00:00
Nuno Cruces
11c1687146 Time collation. 2023-03-09 14:42:29 +00:00
Nuno Cruces
94c43a8685 Use access syscall. 2023-03-09 01:59:46 +00:00
Nuno Cruces
a25159a070 Fix sharing violation. 2023-03-09 01:23:52 +00:00
Nuno Cruces
e007e9b060 Windows fixes. 2023-03-08 20:10:46 +00:00
Nuno Cruces
66a730893f Fix readonly transaction rollback. 2023-03-08 18:07:21 +00:00
Nuno Cruces
926adeb3f5 Remove MustPrepare. 2023-03-08 17:39:41 +00:00
Nuno Cruces
677f51bec1 Savepoint API. 2023-03-08 17:39:23 +00:00
Nuno Cruces
5d6f92b733 Documentation, tests, tweaks. 2023-03-08 13:29:33 +00:00
Nuno Cruces
f5747f19fb Tests. 2023-03-07 14:19:22 +00:00
Nuno Cruces
dfcdbf9c4c Online backup. 2023-03-07 12:15:29 +00:00
Nuno Cruces
ad1e8f4b0e Refactor. 2023-03-07 10:47:55 +00:00
Nuno Cruces
8f29882671 Pass mptest crash. 2023-03-07 04:37:55 +00:00
Nuno Cruces
6c96a019e6 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
d291738b81 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
c1263d4f33 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
1ebdc1aa93 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
4dd10f071a Towards shared modules: backup. 2023-03-07 04:37:55 +00:00
Nuno Cruces
7dbddfa5c0 Towards shared modules. 2023-03-07 04:37:55 +00:00
dependabot[bot]
ce5e035801 Bump golang.org/x/sys from 0.5.0 to 0.6.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.5.0 to 0.6.0.
- [Release notes](https://github.com/golang/sys/releases)
- [Commits](https://github.com/golang/sys/compare/v0.5.0...v0.6.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-03-07 04:11:45 +00:00
Nuno Cruces
8bb8367a36 Refactor mptest. 2023-03-06 14:27:49 +00:00
Nuno Cruces
9f59b3d0ec Pass mptest multiwrite on Windows. 2023-03-05 14:33:51 +00:00
Nuno Cruces
5f893b5459 Add SQLite mptest. 2023-03-05 12:20:02 +00:00
Nuno Cruces
35b1a97b88 Use finalizers to detect unclosed connections. 2023-03-03 14:50:55 +00:00
Nuno Cruces
416c3863a0 Documentation, tests, dependencies. 2023-03-01 23:47:24 +00:00
100 changed files with 6248 additions and 2716 deletions

View File

@@ -5,6 +5,7 @@ on:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
workflow_dispatch:
jobs:
test:
@@ -15,12 +16,30 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
lfs: 'true'
- name: Set up Go
uses: actions/setup-go@v3
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
cache: true
- name: Format
run: gofmt -s -w . && git diff --exit-code
if: matrix.os != 'windows-latest'
- name: Tidy
run: go mod tidy && git diff --exit-code
- name: Download
run: go mod download
- name: Verify
run: go mod verify
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -28,12 +47,15 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test data races
run: go test -v -race ./...
if: matrix.os == 'ubuntu-latest'
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
if: matrix.os == 'macos-latest'
- name: Update coverage report
uses: ncruces/go-coverage-report@main
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'

5
.gitignore vendored
View File

@@ -13,7 +13,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
tools
# Project
sqlite3/sqlite3*
tools

View File

@@ -2,7 +2,7 @@
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
[![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report)
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
@@ -15,9 +15,15 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/sqlite3vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/sqlite3vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
with a [pure Go](sqlite3vfs/) implementation.
This has numerous benefits, but also comes with some drawbacks.
#### Write-Ahead Logging
Because WASM does not support shared memory,
@@ -30,44 +36,47 @@ For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode,
and WAL databases are not supported.
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Open File Description Locks
#### POSIX Advisory Locks
On Unix, this module uses [OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
POSIX advisory locks, which SQLite uses, are [broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
OFD locks are fully compatible with process-associated POSIX advisory locks,
and are supported on Linux, macOS and illumos.
As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.org/uri.html).
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
#### Testing
The pure Go VFS is stress tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Roadmap
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] design a nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on macOS/Linux/Windows
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [ ] online backup
- [ ] snapshots
- [x] online backup
- [ ] session extension
- [ ] resumable bulk update
- [ ] shared cache mode
- [ ] unlock-notify
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] in-memory VFS
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] in-memory VFS, wrapping a [`bytes.Buffer`](https://pkg.go.dev/bytes#Buffer)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom VFS API
- [x] custom VFS API
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)

126
api.go
View File

@@ -1,126 +0,0 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"github.com/tetratelabs/wazero/api"
)
func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
getFun := func(name string) api.Function {
f := module.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := module.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return memory{module}.readUint32(uint32(global.Get()))
}
c := Conn{
ctx: ctx,
mem: memory{module},
api: sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
interrupt: getVal("sqlite3_interrupt_offset"),
},
}
if err != nil {
return nil, err
}
return &c, nil
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
interrupt uint32
}

134
backup.go Normal file
View File

@@ -0,0 +1,134 @@
package sqlite3
// Backup is an handle to an ongoing online backup operation.
//
// https://www.sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle uint32
otherc uint32
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup opens the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
// https://www.sqlite.org/backup.html
func (src *Conn) Backup(srcDB, dstURI string) error {
b, err := src.BackupInit(srcDB, dstURI)
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore opens the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://www.sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
if err != nil {
return err
}
b, err := dst.backupInit(dst.handle, dstDB, src, "main")
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit opens the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupinit
func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
dst, err := src.openDB(dstURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
if err != nil {
return nil, err
}
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
defer c.arena.reset()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
other := dst
if c.handle == dst {
other = src
}
r := c.call(c.api.backupInit,
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r[0] == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r[0], dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r[0]),
}, nil
}
// Close finishes a backup operation.
//
// It is safe to close a nil, zero or closed Backup.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupfinish
func (b *Backup) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r[0])
}
// Step copies up to nPage pages between the source and destination databases.
// If nPage is negative, all remaining source pages are copied.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
if r[0] == _DONE {
return true, nil
}
return false, b.c.error(r[0])
}
// Remaining returns the number of pages still to be backed up
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
return int(r[0])
}
// PageCount returns the total number of pages in the source database
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
return int(r[0])
}

133
blob.go
View File

@@ -1,22 +1,26 @@
package sqlite3
import "io"
import (
"io"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64
// Blob is a handle to an open BLOB.
// Blob is an handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
//
// https://www.sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
handle uint32
bytes int64
offset int64
handle uint32
}
var _ io.ReadWriteSeeker = &Blob{}
@@ -25,6 +29,7 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
@@ -45,7 +50,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
}
blob := Blob{c: c}
blob.handle = c.mem.readUint32(blobPtr)
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
return &blob, nil
}
@@ -81,14 +86,14 @@ func (b *Blob) Read(p []byte) (n int, err error) {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
want := int64(len(p))
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
defer b.c.arena.reset()
ptr := b.c.arena.new(uint64(want))
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
@@ -96,30 +101,68 @@ func (b *Blob) Read(p []byte) (n int, err error) {
if err != nil {
return 0, err
}
mem := b.c.mem.view(ptr, uint64(want))
copy(p, mem)
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
copy(p, util.View(b.c.mod, ptr, uint64(want)))
return int(want), err
}
// WriteTo implements the [io.WriterTo] interface.
//
// https://www.sqlite.org/c3ref/blob_read.html
func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
if b.offset >= b.bytes {
return 0, nil
}
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
for want > 0 {
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return n, err
}
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := w.Write(mem[:want])
b.offset += int64(m)
n += int64(m)
if err != nil {
return n, err
}
if int64(m) != want {
return n, io.ErrShortWrite
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
}
return n, nil
}
// Write implements the [io.Writer] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
offset := b.offset
if offset > b.bytes {
offset = b.bytes
}
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
defer b.c.arena.reset()
ptr := b.c.arena.bytes(p)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(offset))
uint64(ptr), uint64(len(p)), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
@@ -128,11 +171,57 @@ func (b *Blob) Write(p []byte) (n int, err error) {
return len(p), nil
}
// ReadFrom implements the [io.ReaderFrom] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
for {
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := r.Read(mem[:want])
if m > 0 {
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(m), uint64(b.offset))
err := b.c.error(r[0])
if err != nil {
return n, err
}
b.offset += int64(m)
n += int64(m)
}
if err == io.EOF {
return n, nil
}
if err != nil {
return n, err
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
}
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, whenceErr
return 0, util.WhenceErr
case io.SeekStart:
break
case io.SeekCurrent:
@@ -141,7 +230,7 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
offset += b.bytes
}
if offset < 0 {
return 0, offsetErr
return 0, util.OffsetErr
}
b.offset = offset
return offset, nil
@@ -151,8 +240,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))[0])
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
b.offset = 0
return b.c.error(r[0])
return err
}

View File

@@ -1,68 +0,0 @@
package sqlite3
import (
"context"
"crypto/rand"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
var sqlite3 sqlite3Runtime
type sqlite3Runtime struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
}
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.once.Do(func() { s.compileModule(ctx) })
if s.err != nil {
return nil, s.err
}
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10)).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
return
}
}
if bin == nil {
s.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

261
conn.go
View File

@@ -3,14 +3,15 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"math"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Conn is a database connection handle.
@@ -18,56 +19,70 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
ctx context.Context
api sqliteAPI
mem memory
handle uint32
*module
arena arena
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
handle uint32
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (conn *Conn, err error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
module, err := sqlite3.instantiateModule(ctx)
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE
}
return newConn(filename, flags)
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
module.Close(ctx)
mod.close()
} else {
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
c, err := newConn(ctx, module)
c := &Conn{module: mod}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
}
c.arena = c.newArena(1024)
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
defer c.arena.reset()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
return nil, err
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r[0], handle); err != nil {
c.closeDB(handle)
return 0, err
}
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
@@ -80,11 +95,27 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
pragmas.WriteByte(';')
}
}
if err := c.Exec(pragmas.String()); err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
c.closeDB(handle)
return 0, err
}
}
return c, nil
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r[0], handle); err != nil {
panic(err)
}
}
// Close closes the database connection.
@@ -102,6 +133,8 @@ func (c *Conn) Close() error {
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
@@ -109,7 +142,8 @@ func (c *Conn) Close() error {
}
c.handle = 0
return c.mem.mod.Close(c.ctx)
runtime.SetFinalizer(c, nil)
return c.module.close()
}
// Exec is a convenience function that allows an application to run
@@ -125,22 +159,6 @@ func (c *Conn) Exec(sql string) error {
return c.error(r[0])
}
// MustPrepare calls [Conn.Prepare] and panics on error,
// a nil Stmt, or a non-empty tail.
func (c *Conn) MustPrepare(sql string) *Stmt {
s, tail, err := c.PrepareFlags(sql, 0)
if err != nil {
panic(err)
}
if s == nil {
panic(emptyErr)
}
if !emptyStatement(tail) {
panic(tailErr)
}
return s
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
@@ -167,8 +185,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
uint64(stmtPtr), uint64(tailPtr))
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
@@ -228,26 +246,23 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Reset the pending statement.
if c.pending != nil {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
} else {
c.pending.Reset()
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
@@ -263,7 +278,8 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
break
case <-ctx.Done(): // Done was closed.
buf := c.mem.view(c.handle+c.api.interrupt, 4)
const isInterruptedOffset = 280
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
@@ -279,7 +295,8 @@ func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
buf := c.mem.view(c.handle+c.api.interrupt, 4)
const isInterruptedOffset = 280
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
@@ -287,144 +304,22 @@ func (c *Conn) checkInterrupt() bool {
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) []string {
stmt := c.MustPrepare(`PRAGMA ` + str)
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil {
return nil, err
}
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas
return pragmas, stmt.Close()
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
}
var r []uint64
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (c *Conn) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(c.ctx, params...)
if err != nil {
panic(err)
}
return r
}
func (c *Conn) free(ptr uint32) {
if ptr == 0 {
return
}
c.call(c.api.free, uint64(ptr))
}
func (c *Conn) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
}
r := c.call(c.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
}
func (c *Conn) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := c.new(uint64(len(b)))
c.mem.writeBytes(ptr, b)
return ptr
}
func (c *Conn) newString(s string) uint32 {
ptr := c.new(uint64(len(s) + 1))
c.mem.writeString(ptr, s)
return ptr
}
func (c *Conn) newArena(size uint64) arena {
return arena{
c: c,
base: c.new(size),
size: uint32(size),
}
}
type arena struct {
c *Conn
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
if a.c == nil {
return
}
a.reset()
a.c.free(a.base)
a.c = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.c.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.c.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.c.mem.writeString(ptr, s)
return ptr
return c.module.error(rc, c.handle, sql...)
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
@@ -439,6 +334,8 @@ type DriverConn interface {
driver.ExecerContext
driver.ConnPrepareContext
Savepoint() (release func(*error))
Savepoint() Savepoint
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
SetInterrupt(ctx context.Context) (old context.Context)
}

View File

@@ -9,8 +9,7 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_ALLOCATION_SIZE = 0x7ffffeff
@@ -133,46 +132,28 @@ const (
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
// OpenFlag is a flag for a file open operation.
// OpenFlag is a flag for the [OpenFlags] function.
//
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].

View File

@@ -16,7 +16,7 @@
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// If no PRAGMAs are specifed, a busy timeout of 1 minute
// If no PRAGMAs are specified, a busy timeout of 1 minute
// and normal locking mode are used.
//
// [URI]: https://www.sqlite.org/uri.html
@@ -35,6 +35,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
@@ -44,12 +45,12 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
var c conn
c.conn, err = sqlite3.Open(name)
if err != nil {
return nil, err
}
var txBegin string
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
@@ -57,9 +58,9 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
@@ -69,32 +70,52 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
}
if len(pragmas) == 0 {
err := c.Exec(`
PRAGMA busy_timeout=60000;
err := c.conn.Exec(`
PRAGMA locking_mode=normal;
PRAGMA busy_timeout=60000;
`)
if err != nil {
c.Close()
return nil, err
}
c.reusable = true
} else {
s, _, err := c.conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}
return conn{
conn: c,
txBegin: txBegin,
}, nil
return c, nil
}
type conn struct {
conn *sqlite3.Conn
txBegin string
txCommit string
conn *sqlite3.Conn
txBegin string
txCommit string
txRollback string
reusable bool
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.Validator = conn{}
_ sqlite3.DriverConn = conn{}
)
@@ -102,27 +123,36 @@ func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() bool {
return c.reusable
}
func (c conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + c.conn.Pragma("query_only")[0]
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
}
switch opts.Isolation {
default:
return nil, util.IsolationErr
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
}
err := c.conn.Exec(txBegin)
@@ -134,14 +164,14 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil {
if err != nil && !c.conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
return c.conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
@@ -159,7 +189,7 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
if st != nil {
s.Close()
st.Close()
return nil, tailErr
return nil, util.TailErr
}
}
return stmt{s, c.conn}, nil
@@ -183,13 +213,10 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
return nil, err
}
return result{
c.conn.LastInsertRowID(),
c.conn.Changes(),
}, nil
return newResult(c.conn), nil
}
func (c conn) Savepoint() (release func(*error)) {
func (c conn) Savepoint() sqlite3.Savepoint {
return c.conn.Savepoint()
}
@@ -197,6 +224,10 @@ func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite
return c.conn.OpenBlob(db, table, column, row, write)
}
func (c conn) SetInterrupt(ctx context.Context) (old context.Context) {
return c.conn.SetInterrupt(ctx)
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
@@ -246,10 +277,7 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
return nil, err
}
return result{
int64(s.conn.LastInsertRowID()),
int64(s.conn.Changes()),
}, nil
return newResult(s.conn), nil
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -288,11 +316,11 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
panic(util.AssertErr())
}
}
if err != nil {
@@ -313,6 +341,17 @@ func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
}
}
func newResult(c *sqlite3.Conn) driver.Result {
rows := c.Changes()
if rows != 0 {
id := c.LastInsertRowID()
if id != 0 {
return result{id, rows}
}
}
return resultRowsAffected(rows)
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {
@@ -323,6 +362,16 @@ func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type resultRowsAffected int64
func (r resultRowsAffected) LastInsertId() (int64, error) {
return 0, nil
}
func (r resultRowsAffected) RowsAffected() (int64, error) {
return int64(r), nil
}
type rows struct {
ctx context.Context
stmt *sqlite3.Stmt
@@ -359,11 +408,10 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = r.stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeTime(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)
dest[i] = r.stmt.ColumnRawBlob(i)
case sqlite3.TEXT:
dest[i] = stringOrTime(r.stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
@@ -371,7 +419,7 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = nil
}
default:
panic(assertErr)
panic(util.AssertErr())
}
}

View File

@@ -1,4 +1,3 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
@@ -12,6 +11,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_Open_dir(t *testing.T) {
@@ -134,14 +134,16 @@ func Test_BeginTx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(isolationErr) {
if err.Error() != string(util.IsolationErr) {
t.Error("want isolationErr")
}
@@ -219,7 +221,7 @@ func Test_Prepare(t *testing.T) {
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
if err.Error() != string(tailErr) {
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
}
@@ -297,7 +299,7 @@ func Test_QueryRow_blob_null(t *testing.T) {
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
for i := 0; rows.Next(); i++ {
var buf []byte
var buf sql.RawBytes
err = rows.Scan(&buf)
if err != nil {
t.Fatal(err)

View File

@@ -1,11 +0,0 @@
package driver
type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
isolationErr = errorString("sqlite3: unsupported isolation level")
)

View File

@@ -28,8 +28,8 @@ func Example() {
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
defer db.Close()
// Create a table with some data in it.
err = albumsSetup()

View File

@@ -9,23 +9,23 @@ import (
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) driver.Value {
func stringOrTime(text []byte) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return text
return string(text)
}
if len(text) > len(time.RFC3339Nano) {
return text
return string(text)
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return text
return string(text)
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && date.Format(time.RFC3339Nano) == string(text) {
return date
}
return text
return string(text)
}

View File

@@ -6,7 +6,7 @@ import (
)
// This checks that any string can be recovered as the same string.
func Fuzz_maybeTime_1(f *testing.F) {
func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
@@ -22,7 +22,7 @@ func Fuzz_maybeTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeTime(str)
value := stringOrTime([]byte(str))
switch v := value.(type) {
case time.Time:
@@ -48,7 +48,7 @@ func Fuzz_maybeTime_1(f *testing.F) {
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_maybeTime_2(f *testing.F) {
func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
@@ -59,7 +59,7 @@ func Fuzz_maybeTime_2(f *testing.F) {
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
value := maybeTime(date.Format(time.RFC3339Nano))
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
switch v := value.(type) {
case time.Time:

View File

@@ -20,8 +20,8 @@ func ExampleDriverConn() {
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("demo.db")
defer db.Close()
ctx := context.Background()
@@ -48,7 +48,8 @@ func ExampleDriverConn() {
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
defer conn.Savepoint()(&err)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {

23
embed/README.md Normal file
View File

@@ -0,0 +1,23 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
- [math functions](https://www.sqlite.org/lang_mathfunc.html)
- [FTS3/4](https://www.sqlite.org/fts3.html)/[5](https://www.sqlite.org/fts5.html)
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
- [series](https://github.com/sqlite/sqlite/blob/master/ext/misc/series.c)
- [uint](https://github.com/sqlite/sqlite/blob/master/ext/misc/uint.c)
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
- [time](../sqlite3/time.c)
See the [configuration options](../sqlite3/sqlite_cfg.h).
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
and [`binaryen`](https://github.com/WebAssembly/binaryen).

View File

@@ -1,63 +1,27 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
# download SQLite
../sqlite3/download.sh
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
# build SQLite
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o sqlite3.wasm ../sqlite3/amalg.c \
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-Wl,--export=free \
-Wl,--export=malloc \
-Wl,--export=malloc_destructor \
-Wl,--export=sqlite3_errcode \
-Wl,--export=sqlite3_errstr \
-Wl,--export=sqlite3_errmsg \
-Wl,--export=sqlite3_error_offset \
-Wl,--export=sqlite3_open_v2 \
-Wl,--export=sqlite3_close \
-Wl,--export=sqlite3_prepare_v3 \
-Wl,--export=sqlite3_finalize \
-Wl,--export=sqlite3_reset \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_parameter_index \
-Wl,--export=sqlite3_bind_parameter_name \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_blob_open \
-Wl,--export=sqlite3_blob_close \
-Wl,--export=sqlite3_blob_bytes \
-Wl,--export=sqlite3_blob_read \
-Wl,--export=sqlite3_blob_write \
-Wl,--export=sqlite3_blob_reopen \
-Wl,--export=sqlite3_get_autocommit \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_unlock_notify \
-Wl,--export=sqlite3_backup_init \
-Wl,--export=sqlite3_backup_step \
-Wl,--export=sqlite3_backup_finish \
-Wl,--export=sqlite3_backup_remaining \
-Wl,--export=sqlite3_backup_pagecount \
-Wl,--export=sqlite3_interrupt_offset \
-Wl,--initial-memory=327680 \
-Wl,--stack-first \
-Wl,--import-undefined \
$(awk '{print "-Wl,--export="$0}' exports.txt)
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

47
embed/exports.txt Normal file
View File

@@ -0,0 +1,47 @@
free
malloc
malloc_destructor
sqlite3_errcode
sqlite3_errstr
sqlite3_errmsg
sqlite3_error_offset
sqlite3_open_v2
sqlite3_close
sqlite3_close_v2
sqlite3_prepare_v3
sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
sqlite3_bind_parameter_name
sqlite3_bind_null
sqlite3_bind_int64
sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
sqlite3_column_int64
sqlite3_column_double
sqlite3_column_text
sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount

View File

@@ -4,9 +4,6 @@
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
package embed
import (

Binary file not shown.

112
error.go
View File

@@ -1,19 +1,20 @@
package sqlite3
import (
"runtime"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Error wraps an SQLite Error Code.
//
// https://www.sqlite.org/c3ref/errcode.html
type Error struct {
code uint64
str string
msg string
sql string
code uint64
}
// Code returns the primary error code for this error.
@@ -84,72 +85,7 @@ func (e *Error) SQL() string {
// Error implements the error interface.
func (e ErrorCode) Error() string {
switch e {
case _OK:
return "sqlite3: not an error"
case _ROW:
return "sqlite3: another row available"
case _DONE:
return "sqlite3: no more rows available"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
return util.ErrorCodeString(uint32(e))
}
// Temporary returns true for [BUSY] errors.
@@ -159,17 +95,7 @@ func (e ErrorCode) Temporary() bool {
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
switch x := ErrorCode(e); {
case e == ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case x < _ROW:
return x.Error()
case e == _ROW:
return "sqlite3: another row available"
case e == _DONE:
return "sqlite3: no more rows available"
}
return "sqlite3: unknown error"
return util.ErrorCodeString(uint32(e))
}
// Is tests whether this error matches a given [ErrorCode].
@@ -187,31 +113,3 @@ func (e ExtendedErrorCode) Temporary() bool {
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
type errorString string
func (e errorString) Error() string { return string(e) }
const (
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
emptyErr = errorString("sqlite3: empty statement")
tailErr = errorString("sqlite3: non-empty tail")
notImplErr = errorString("sqlite3: not implemented")
whenceErr = errorString("sqlite3: invalid whence")
offsetErr = errorString("sqlite3: invalid offset")
)
func assertErr() errorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return errorString(msg)
}

View File

@@ -1,15 +1,16 @@
package sqlite3
import (
"context"
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") {
err := util.AssertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") {
t.Errorf("got %q", s)
}
}
@@ -120,10 +121,8 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
@@ -144,10 +143,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {

View File

@@ -26,7 +26,11 @@ func Example() {
log.Fatal(err)
}
stmt := db.MustPrepare(`SELECT id, name FROM users`)
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))

8
go.mod
View File

@@ -4,7 +4,9 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-pre.9
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
github.com/tetratelabs/wazero v1.1.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.0
)
retract v0.4.0 // tagged from the wrong branch

12
go.sum
View File

@@ -1,8 +1,8 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ=
github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

116
internal/util/const.go Normal file
View File

@@ -0,0 +1,116 @@
package util
// https://sqlite.com/matrix/rescode.html
const (
OK = 0 /* Successful result */
ERROR = 1 /* Generic error */
INTERNAL = 2 /* Internal logic error in SQLite */
PERM = 3 /* Access permission denied */
ABORT = 4 /* Callback routine requested an abort */
BUSY = 5 /* The database file is locked */
LOCKED = 6 /* A table in the database is locked */
NOMEM = 7 /* A malloc() failed */
READONLY = 8 /* Attempt to write a readonly database */
INTERRUPT = 9 /* Operation terminated by sqlite3_interrupt() */
IOERR = 10 /* Some kind of disk I/O error occurred */
CORRUPT = 11 /* The database disk image is malformed */
NOTFOUND = 12 /* Unknown opcode in sqlite3_file_control() */
FULL = 13 /* Insertion failed because database is full */
CANTOPEN = 14 /* Unable to open the database file */
PROTOCOL = 15 /* Database lock protocol error */
EMPTY = 16 /* Internal use only */
SCHEMA = 17 /* The database schema changed */
TOOBIG = 18 /* String or BLOB exceeds size limit */
CONSTRAINT = 19 /* Abort due to constraint violation */
MISMATCH = 20 /* Data type mismatch */
MISUSE = 21 /* Library used incorrectly */
NOLFS = 22 /* Uses OS features not supported on host */
AUTH = 23 /* Authorization denied */
FORMAT = 24 /* Not used */
RANGE = 25 /* 2nd parameter to sqlite3_bind out of range */
NOTADB = 26 /* File opened that is not a database file */
NOTICE = 27 /* Notifications from sqlite3_log() */
WARNING = 28 /* Warnings from sqlite3_log() */
ROW = 100 /* sqlite3_step() has another row ready */
DONE = 101 /* sqlite3_step() has finished executing */
ERROR_MISSING_COLLSEQ = ERROR | (1 << 8)
ERROR_RETRY = ERROR | (2 << 8)
ERROR_SNAPSHOT = ERROR | (3 << 8)
IOERR_READ = IOERR | (1 << 8)
IOERR_SHORT_READ = IOERR | (2 << 8)
IOERR_WRITE = IOERR | (3 << 8)
IOERR_FSYNC = IOERR | (4 << 8)
IOERR_DIR_FSYNC = IOERR | (5 << 8)
IOERR_TRUNCATE = IOERR | (6 << 8)
IOERR_FSTAT = IOERR | (7 << 8)
IOERR_UNLOCK = IOERR | (8 << 8)
IOERR_RDLOCK = IOERR | (9 << 8)
IOERR_DELETE = IOERR | (10 << 8)
IOERR_BLOCKED = IOERR | (11 << 8)
IOERR_NOMEM = IOERR | (12 << 8)
IOERR_ACCESS = IOERR | (13 << 8)
IOERR_CHECKRESERVEDLOCK = IOERR | (14 << 8)
IOERR_LOCK = IOERR | (15 << 8)
IOERR_CLOSE = IOERR | (16 << 8)
IOERR_DIR_CLOSE = IOERR | (17 << 8)
IOERR_SHMOPEN = IOERR | (18 << 8)
IOERR_SHMSIZE = IOERR | (19 << 8)
IOERR_SHMLOCK = IOERR | (20 << 8)
IOERR_SHMMAP = IOERR | (21 << 8)
IOERR_SEEK = IOERR | (22 << 8)
IOERR_DELETE_NOENT = IOERR | (23 << 8)
IOERR_MMAP = IOERR | (24 << 8)
IOERR_GETTEMPPATH = IOERR | (25 << 8)
IOERR_CONVPATH = IOERR | (26 << 8)
IOERR_VNODE = IOERR | (27 << 8)
IOERR_AUTH = IOERR | (28 << 8)
IOERR_BEGIN_ATOMIC = IOERR | (29 << 8)
IOERR_COMMIT_ATOMIC = IOERR | (30 << 8)
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
IOERR_DATA = IOERR | (32 << 8)
IOERR_CORRUPTFS = IOERR | (33 << 8)
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
LOCKED_VTAB = LOCKED | (2 << 8)
BUSY_RECOVERY = BUSY | (1 << 8)
BUSY_SNAPSHOT = BUSY | (2 << 8)
BUSY_TIMEOUT = BUSY | (3 << 8)
CANTOPEN_NOTEMPDIR = CANTOPEN | (1 << 8)
CANTOPEN_ISDIR = CANTOPEN | (2 << 8)
CANTOPEN_FULLPATH = CANTOPEN | (3 << 8)
CANTOPEN_CONVPATH = CANTOPEN | (4 << 8)
CANTOPEN_DIRTYWAL = CANTOPEN | (5 << 8) /* Not Used */
CANTOPEN_SYMLINK = CANTOPEN | (6 << 8)
CORRUPT_VTAB = CORRUPT | (1 << 8)
CORRUPT_SEQUENCE = CORRUPT | (2 << 8)
CORRUPT_INDEX = CORRUPT | (3 << 8)
READONLY_RECOVERY = READONLY | (1 << 8)
READONLY_CANTLOCK = READONLY | (2 << 8)
READONLY_ROLLBACK = READONLY | (3 << 8)
READONLY_DBMOVED = READONLY | (4 << 8)
READONLY_CANTINIT = READONLY | (5 << 8)
READONLY_DIRECTORY = READONLY | (6 << 8)
ABORT_ROLLBACK = ABORT | (2 << 8)
CONSTRAINT_CHECK = CONSTRAINT | (1 << 8)
CONSTRAINT_COMMITHOOK = CONSTRAINT | (2 << 8)
CONSTRAINT_FOREIGNKEY = CONSTRAINT | (3 << 8)
CONSTRAINT_FUNCTION = CONSTRAINT | (4 << 8)
CONSTRAINT_NOTNULL = CONSTRAINT | (5 << 8)
CONSTRAINT_PRIMARYKEY = CONSTRAINT | (6 << 8)
CONSTRAINT_TRIGGER = CONSTRAINT | (7 << 8)
CONSTRAINT_UNIQUE = CONSTRAINT | (8 << 8)
CONSTRAINT_VTAB = CONSTRAINT | (9 << 8)
CONSTRAINT_ROWID = CONSTRAINT | (10 << 8)
CONSTRAINT_PINNED = CONSTRAINT | (11 << 8)
CONSTRAINT_DATATYPE = CONSTRAINT | (12 << 8)
NOTICE_RECOVER_WAL = NOTICE | (1 << 8)
NOTICE_RECOVER_ROLLBACK = NOTICE | (2 << 8)
NOTICE_RBU = NOTICE | (3 << 8)
WARNING_AUTOINDEX = WARNING | (1 << 8)
AUTH_USER = AUTH | (1 << 8)
OK_LOAD_PERMANENTLY = OK | (1 << 8)
OK_SYMLINK = OK | (2 << 8) /* internal use only */
)

115
internal/util/error.go Normal file
View File

@@ -0,0 +1,115 @@
package util
import (
"fmt"
"runtime"
"strconv"
)
type ErrorString string
func (e ErrorString) Error() string { return string(e) }
const (
NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference")
OOMErr = ErrorString("sqlite3: out of memory")
RangeErr = ErrorString("sqlite3: index out of range")
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
NoGlobalErr = ErrorString("sqlite3: could not find global: ")
NoFuncErr = ErrorString("sqlite3: could not find function: ")
BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
TimeErr = ErrorString("sqlite3: invalid time value")
WhenceErr = ErrorString("sqlite3: invalid whence")
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
)
func AssertErr() ErrorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return ErrorString(msg)
}
func Finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(ErrorString(msg)) }
}
func ErrorCodeString(rc uint32) string {
switch rc {
case ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case ROW:
return "sqlite3: another row available"
case DONE:
return "sqlite3: no more rows available"
}
switch rc & 0xff {
case OK:
return "sqlite3: not an error"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
}

102
internal/util/func.go Normal file
View File

@@ -0,0 +1,102 @@
package util
import (
"context"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
type i32 interface{ ~int32 | ~uint32 }
type i64 interface{ ~int64 | ~uint64 }
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
}
func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcII[TR, T0](fn),
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}
func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIII[TR, T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
}
func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIII[TR, T0, T1, T2](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}
func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIII[TR, T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
}
func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIII[TR, T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}
func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIJ[TR, T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}
func ExportFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIJ[TR, T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}

110
internal/util/mem.go Normal file
View File

@@ -0,0 +1,110 @@
package util
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
func View(mod api.Module, ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(NilErr)
}
if size > math.MaxUint32 {
panic(RangeErr)
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)
}
return buf
}
func ReadUint32(mod api.Module, ptr uint32) uint32 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint32(mod api.Module, ptr uint32, v uint32) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadUint64(mod api.Module, ptr uint32) uint64 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint64(mod api.Module, ptr uint32, v uint64) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadFloat64(mod api.Module, ptr uint32) float64 {
return math.Float64frombits(ReadUint64(mod, ptr))
}
func WriteFloat64(mod api.Module, ptr uint32, v float64) {
WriteUint64(mod, ptr, math.Float64bits(v))
}
func ReadString(mod api.Module, ptr, maxlen uint32) string {
if ptr == 0 {
panic(NilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
return string(buf[:i])
}
}
func WriteBytes(mod api.Module, ptr uint32, b []byte) {
buf := View(mod, ptr, uint64(len(b)))
copy(buf, b)
}
func WriteString(mod api.Module, ptr uint32, s string) {
buf := View(mod, ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

90
internal/util/mem_test.go Normal file
View File

@@ -0,0 +1,90 @@
package util
import (
"math"
"testing"
)
func TestView_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 0, 8)
t.Error("want panic")
}
func TestView_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 126, 8)
t.Error("want panic")
}
func TestView_overflow(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 1, math.MaxInt64)
t.Error("want panic")
}
func TestReadUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 0)
t.Error("want panic")
}
func TestReadUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 126)
t.Error("want panic")
}
func TestReadUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 0)
t.Error("want panic")
}
func TestReadUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 126)
t.Error("want panic")
}
func TestWriteUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 126, 1)
t.Error("want panic")
}
func TestWriteUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 126, 1)
t.Error("want panic")
}
func TestReadString_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadString(mock, 130, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -1,63 +1,54 @@
package sqlite3
package util
import (
"context"
"encoding/binary"
"math"
"github.com/tetratelabs/wazero/api"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func newMemory(size uint32) memory {
mem := make(mockMemory, size)
return memory{mockModule{&mem}}
func NewMockModule(size uint32) api.Module {
mem := mockMemory{buf: make([]byte, size)}
return mockModule{&mem, nil}
}
type mockModule struct {
memory api.Memory
api.Module
}
func (m mockModule) Memory() api.Memory { return m.memory }
func (m mockModule) String() string { return "mockModule" }
func (m mockModule) Name() string { return "mockModule" }
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
func (m mockModule) Close(context.Context) error { return nil }
type mockMemory []byte
type mockMemory struct {
buf []byte
api.Memory
}
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
func (m mockMemory) Size() uint32 { return uint32(len(m.buf)) }
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
if offset >= m.Size() {
return 0, false
}
return m[offset], true
return m.buf[offset], true
}
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
if !m.hasSize(offset, 2) {
return 0, false
}
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
return binary.LittleEndian.Uint16(m.buf[offset : offset+2]), true
}
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
if !m.hasSize(offset, 4) {
return 0, false
}
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
return binary.LittleEndian.Uint32(m.buf[offset : offset+4]), true
}
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
@@ -72,7 +63,7 @@ func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
if !m.hasSize(offset, 8) {
return 0, false
}
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
return binary.LittleEndian.Uint64(m.buf[offset : offset+8]), true
}
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
@@ -87,14 +78,14 @@ func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
if !m.hasSize(offset, byteCount) {
return nil, false
}
return m[offset : offset+byteCount : offset+byteCount], true
return m.buf[offset : offset+byteCount : offset+byteCount], true
}
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
if offset >= m.Size() {
return false
}
m[offset] = v
m.buf[offset] = v
return true
}
@@ -102,7 +93,7 @@ func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
if !m.hasSize(offset, 2) {
return false
}
binary.LittleEndian.PutUint16(m[offset:], v)
binary.LittleEndian.PutUint16(m.buf[offset:], v)
return true
}
@@ -110,7 +101,7 @@ func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
if !m.hasSize(offset, 4) {
return false
}
binary.LittleEndian.PutUint32(m[offset:], v)
binary.LittleEndian.PutUint32(m.buf[offset:], v)
return true
}
@@ -122,7 +113,7 @@ func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
if !m.hasSize(offset, 8) {
return false
}
binary.LittleEndian.PutUint64(m[offset:], v)
binary.LittleEndian.PutUint64(m.buf[offset:], v)
return true
}
@@ -134,7 +125,7 @@ func (m mockMemory) Write(offset uint32, val []byte) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
copy(m.buf[offset:], val)
return true
}
@@ -142,20 +133,16 @@ func (m mockMemory) WriteString(offset uint32, val string) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
copy(m.buf[offset:], val)
return true
}
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
prev := (len(*m) + 65535) / 65536
*m = append(*m, make([]byte, 65536*delta)...)
prev := (len(m.buf) + 65535) / 65536
m.buf = append(m.buf, make([]byte, 65536*delta)...)
return uint32(prev), true
}
func (m mockMemory) PageSize() (result uint32) {
return uint32(len(m) / 65536)
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
return uint64(offset)+uint64(byteCount) <= uint64(len(m.buf))
}

242
internal/util/mock_test.go Normal file
View File

@@ -0,0 +1,242 @@
package util
import (
"math"
"testing"
)
func Test_mockMemory_byte(t *testing.T) {
const want byte = 98
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadByte(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint16(t *testing.T) {
const want uint16 = 9876
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint16Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint16Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint32(t *testing.T) {
const want uint32 = 987654321
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint64(t *testing.T) {
const want uint64 = 9876543210
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_float32(t *testing.T) {
const want float32 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_float64(t *testing.T) {
const want float64 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_bytes(t *testing.T) {
const want string = "\xca\xfe\xba\xbe"
mock := NewMockModule(128)
_, ok := mock.Memory().Read(128, uint32(len(want)))
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(128, []byte(want))
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteString(128, want)
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(0, []byte(want))
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().Read(0, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
ok = mock.Memory().WriteString(64, want)
if !ok {
t.Error("want ok")
}
got, ok = mock.Memory().Read(64, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func Test_mockMemory_grow(t *testing.T) {
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(65536)
if ok {
t.Error("want error")
}
got, ok := mock.Memory().Grow(1)
if !ok {
t.Error("want ok")
}
if got != 1 {
t.Errorf("got %d, want 1", got)
}
_, ok = mock.Memory().ReadByte(65536)
if !ok {
t.Error("want ok")
}
}

114
mem.go
View File

@@ -1,114 +0,0 @@
package sqlite3
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
type memory struct {
mod api.Module
}
func (m memory) view(ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(nilErr)
}
if size > math.MaxUint32 {
panic(rangeErr)
}
buf, ok := m.mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(rangeErr)
}
return buf
}
func (m memory) readUint32(ptr uint32) uint32 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint32(ptr, v uint32) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readUint64(ptr uint32) uint64 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint64(ptr uint32, v uint64) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readFloat64(ptr uint32) float64 {
return math.Float64frombits(m.readUint64(ptr))
}
func (m memory) writeFloat64(ptr uint32, v float64) {
m.writeUint64(ptr, math.Float64bits(v))
}
func (m memory) readString(ptr, maxlen uint32) string {
if ptr == 0 {
panic(nilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := m.mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(rangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(noNulErr)
} else {
return string(buf[:i])
}
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint64(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
buf := m.view(ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -1,90 +0,0 @@
package sqlite3
import (
"math"
"testing"
)
func Test_memory_view_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(0, 8)
t.Error("want panic")
}
func Test_memory_view_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(126, 8)
t.Error("want panic")
}
func Test_memory_view_overflow(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(1, math.MaxInt64)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(0)
t.Error("want panic")
}
func Test_memory_readUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(126)
t.Error("want panic")
}
func Test_memory_readUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(0)
t.Error("want panic")
}
func Test_memory_readUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(126)
t.Error("want panic")
}
func Test_memory_writeUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(126, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(126, 1)
t.Error("want panic")
}
func Test_memory_readString_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readString(130, math.MaxUint32)
t.Error("want panic")
}

358
module.go Normal file
View File

@@ -0,0 +1,358 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"io"
"math"
"os"
"sync"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
var sqlite3 struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
ctx := context.Background()
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
env := sqlite3vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
return
}
}
if bin == nil {
sqlite3.err = util.BinaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mod api.Module
vfs io.Closer
api sqliteAPI
arg [8]uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mod = mod
m.ctx, m.vfs = sqlite3vfs.NewContext(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
if f == nil {
err = util.NoFuncErr + util.ErrorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
g := mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return util.ReadUint32(mod, uint32(g.Get()))
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
}
if err != nil {
return nil, err
}
return m, nil
}
func (m *module) close() error {
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(util.OOMErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r = m.call(m.api.erroff, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return m.arg[:]
}
func (m *module) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
util.WriteBytes(m.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
util.WriteString(m.mod, ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
size: uint32(size),
}
}
type arena struct {
m *module
ptrs []uint32
base uint32
next uint32
size uint32
}
func (a *arena) free() {
if a.m == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) bytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
util.WriteString(a.m.mod, ptr, s)
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
destructor uint32
}

View File

@@ -4,65 +4,75 @@ import (
"bytes"
"math"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.error(uint64(NOMEM))
m.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
func TestConn_call_closed(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
m.close()
defer func() { _ = recover() }()
db.call(db.api.free)
m.call(m.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
testOOM := func(size uint64) {
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
db.new(size)
m.new(math.MaxUint32)
t.Error("want panic")
}
})
testOOM(math.MaxUint32)
testOOM(_MAX_ALLOCATION_SIZE)
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
arena := db.newArena(16)
arena := m.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -71,7 +81,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -80,7 +90,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
@@ -89,25 +99,25 @@ func TestConn_newArena(t *testing.T) {
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newBytes(nil)
ptr := m.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = db.newBytes(buf)
ptr = m.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := db.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -115,25 +125,25 @@ func TestConn_newBytes(t *testing.T) {
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := db.mem.view(ptr, uint64(len(want))); string(got) != want {
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -141,40 +151,40 @@ func TestConn_newString(t *testing.T) {
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := db.mem.readString(ptr, 0); got != "" {
if got := util.ReadString(m.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
db.mem.readString(ptr, uint32(len(want)/2))
util.ReadString(m.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
db.mem.readString(0, math.MaxUint32)
util.ReadString(m.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}
@@ -182,18 +192,18 @@ func TestConn_getString(t *testing.T) {
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
db.free(0)
m.free(0)
ptr := db.new(1)
ptr := m.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
db.free(ptr)
m.free(ptr)
}

3
sqlite3/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
sqlite3.c
sqlite3.h
sqlite3ext.h

View File

@@ -1,9 +0,0 @@
#include <stddef.h>
#include "main.c"
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);

View File

@@ -1,13 +1,31 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "sqlite3.c" ]; then
url="https://sqlite.org/2023/sqlite-amalgamation-3410000.zip"
curl "$url" > sqlite.zip
unzip -d . sqlite.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
rm sqlite.zip
fi
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
cd ~-
cd ../sqlite3vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
cd ~-
cd ../sqlite3vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
cd ~-

1
sqlite3/ext/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.c

View File

@@ -1,8 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
clang-format -i \
main.c \
os.c \
qsort.c \
amalg.c
shopt -s extglob
clang-format -i !(sqlite3*).@(c|h)

View File

@@ -1,14 +1,28 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.h"
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS
#include "vfs.c"
// Extensions
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
#include "time.c"
int main() {
int rc = sqlite3_initialize();
if (rc != SQLITE_OK) return 1;
}
sqlite3_vfs *os_vfs();
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
__attribute__((constructor)) void init() {
sqlite3_initialize();
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
sqlite3_auto_extension((void (*)(void))sqlite3_decimal_init);
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}

View File

@@ -1,92 +0,0 @@
#include <time.h>
#include "sqlite3.h"
int os_localtime(sqlite3_int64, struct tm *);
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
int os_sleep(sqlite3_vfs *, int microseconds);
int os_current_time(sqlite3_vfs *, double *);
int os_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int os_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int os_delete(sqlite3_vfs *, const char *zName, int syncDir);
int os_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct os_file {
sqlite3_file base;
int id;
int lock;
};
int os_close(sqlite3_file *);
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int os_truncate(sqlite3_file *, sqlite3_int64 size);
int os_sync(sqlite3_file *, int flags);
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int os_file_control(sqlite3_file *pFile, int op, void *pArg);
int os_lock(sqlite3_file *pFile, int eLock);
int os_unlock(sqlite3_file *pFile, int eLock);
int os_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime((sqlite3_int64)*pTime, pTm);
}
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = os_close,
.xRead = os_read,
.xWrite = os_write,
.xTruncate = os_truncate,
.xSync = os_sync,
.xFileSize = os_file_size,
.xLock = os_lock,
.xUnlock = os_unlock,
.xCheckReservedLock = os_check_reserved_lock,
.xFileControl = no_file_control,
.xDeviceCharacteristics = no_device_characteristics,
};
int rc = os_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &os_io : NULL;
return rc;
}
sqlite3_vfs *os_vfs() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct os_file),
.mxPathname = 512,
.zName = "os",
.xOpen = os_open_w,
.xDelete = os_delete,
.xAccess = os_access,
.xFullPathname = os_full_pathname,
.xRandomness = os_randomness,
.xSleep = os_sleep,
.xCurrentTime = os_current_time,
.xCurrentTimeInt64 = os_current_time_64,
};
return &os_vfs;
}

View File

@@ -1,14 +0,0 @@
#include <stddef.h>
void qsort_r(void *, size_t, size_t,
int (*)(const void *, const void *, void *), void *);
typedef int (*cmpfun)(const void *, const void *);
static int wrapper_cmp(const void *v1, const void *v2, void *cmp) {
return ((cmpfun)cmp)(v1, v2);
}
void qsort(void *base, size_t nel, size_t width, cmpfun cmp) {
qsort_r(base, nel, width, wrapper_cmp, cmp);
}

View File

@@ -28,14 +28,19 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
// Because WASM does not support shared memory,
// SQLite disables it for WASM builds.
// SQLite disables WAL for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#ifndef SQLITE_DEFAULT_LOCKING_MODE
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Recommended Extensions
// Amalgamated Extensions
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
#define SQLITE_ENABLE_JSON1 1
@@ -46,15 +51,20 @@
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
// Snapshot
// #define SQLITE_ENABLE_SNAPSHOT 1
// Session Extension
// #define SQLITE_ENABLE_SESSION 1
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
// Resumable Bulk Update Extension
// #define SQLITE_ENABLE_RBU 1
// https://stackoverflow.com/a/50616684
#define SECOND(...) SECOND_I(__VA_ARGS__, , )
#define SECOND_I(A, B, ...) B
#define GLUE(A, B) GLUE_I(A, B)
#define GLUE_I(A, B) A##_##B
#define CREATE_REPLACER(A) SECOND(GLUE(A, __LINE__), A)
#define REPLACE_AT_LINE(A) , A
// Implemented in Go.
int localtime_s(struct tm *const pTm, time_t const *const pTime);
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);
#define sqlite3_vfs_find CREATE_REPLACER(sqlite3_vfs_find_wrapper)
#define sqlite3_vfs_find_wrapper_25397 REPLACE_AT_LINE(sqlite3_vfs_find)

32
sqlite3/time.c Normal file
View File

@@ -0,0 +1,32 @@
#include <string.h>
#include "sqlite3.h"
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
const void *pKey2) {
// Remove a Z suffix if one key is no longer than the other.
// A Z suffix collates before any character but after the empty string.
// This avoids making different keys equal.
const int nK1 = nKey1;
const int nK2 = nKey2;
const char *pK1 = (const char *)pKey1;
const char *pK2 = (const char *)pKey2;
if (nK1 && nK1 <= nK2 && pK1[nK1 - 1] == 'Z') {
nKey1--;
}
if (nK2 && nK2 <= nK1 && pK2[nK2 - 1] == 'Z') {
nKey2--;
}
int n = nKey1 < nKey2 ? nKey1 : nKey2;
int rc = memcmp(pKey1, pKey2, n);
if (rc == 0) {
rc = nKey1 - nKey2;
}
return rc;
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
}

138
sqlite3/vfs.c Normal file
View File

@@ -0,0 +1,138 @@
#include <time.h>
#include "sqlite3.h"
int go_localtime(struct tm *, sqlite3_int64);
int go_vfs_find(const char *zVfsName);
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
int go_sleep(sqlite3_vfs *, int microseconds);
int go_current_time(sqlite3_vfs *, double *);
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
int go_close(sqlite3_file *);
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int go_truncate(sqlite3_file *, sqlite3_int64 size);
int go_sync(sqlite3_file *, int flags);
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int go_file_control(sqlite3_file *, int op, void *pArg);
int go_sector_size(sqlite3_file *file);
int go_device_characteristics(sqlite3_file *file);
int go_lock(sqlite3_file *, int eLock);
int go_unlock(sqlite3_file *, int eLock);
int go_check_reserved_lock(sqlite3_file *, int *pResOut);
static int go_open_wrapper(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = go_close,
.xRead = go_read,
.xWrite = go_write,
.xTruncate = go_truncate,
.xSync = go_sync,
.xFileSize = go_file_size,
.xLock = go_lock,
.xUnlock = go_unlock,
.xCheckReservedLock = go_check_reserved_lock,
.xFileControl = go_file_control,
.xSectorSize = go_sector_size,
.xDeviceCharacteristics = go_device_characteristics,
};
memset(file, 0, vfs->szOsFile);
int rc = go_open(vfs, zName, file, flags, pOutFlags);
if (rc) {
return rc;
}
file->pMethods = &os_io;
return SQLITE_OK;
}
struct go_file {
sqlite3_file base;
int handle;
};
int sqlite3_os_init() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = "os",
.xOpen = go_open_wrapper,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&os_vfs, /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime(pTm, (sqlite3_int64)*pTime);
}
#undef sqlite3_vfs_find
sqlite3_vfs *sqlite3_vfs_find_wrapper(const char *zVfsName) {
if (zVfsName) {
static sqlite3_vfs *go_vfs_list;
sqlite3_vfs *found = NULL;
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
sqlite3_vfs *it = *next;
if (go_vfs_find(it->zName)) {
if (!strcmp(zVfsName, it->zName)) found = it;
next = &it->pNext;
} else {
*next = it->pNext;
free(it);
}
}
if (found) {
return found;
}
if (go_vfs_find(zVfsName)) {
sqlite3_vfs *prev = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
*go_vfs_list = (sqlite3_vfs){
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = name,
.pNext = prev,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return go_vfs_list;
}
}
return sqlite3_vfs_find(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 280, "Unexpected offset");

210
sqlite3vfs/const.go Normal file
View File

@@ -0,0 +1,210 @@
package sqlite3vfs
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_DEFAULT_SECTOR_SIZE = 4096
)
// https://www.sqlite.org/rescode.html
type _ErrorCode uint32
func (e _ErrorCode) Error() string {
return util.ErrorCodeString(uint32(e))
}
const (
_OK _ErrorCode = util.OK
_PERM _ErrorCode = util.PERM
_BUSY _ErrorCode = util.BUSY
_IOERR _ErrorCode = util.IOERR
_NOTFOUND _ErrorCode = util.NOTFOUND
_CANTOPEN _ErrorCode = util.CANTOPEN
_IOERR_READ _ErrorCode = util.IOERR_READ
_IOERR_SHORT_READ _ErrorCode = util.IOERR_SHORT_READ
_IOERR_WRITE _ErrorCode = util.IOERR_WRITE
_IOERR_FSYNC _ErrorCode = util.IOERR_FSYNC
_IOERR_DIR_FSYNC _ErrorCode = util.IOERR_DIR_FSYNC
_IOERR_TRUNCATE _ErrorCode = util.IOERR_TRUNCATE
_IOERR_FSTAT _ErrorCode = util.IOERR_FSTAT
_IOERR_UNLOCK _ErrorCode = util.IOERR_UNLOCK
_IOERR_RDLOCK _ErrorCode = util.IOERR_RDLOCK
_IOERR_DELETE _ErrorCode = util.IOERR_DELETE
_IOERR_ACCESS _ErrorCode = util.IOERR_ACCESS
_IOERR_CHECKRESERVEDLOCK _ErrorCode = util.IOERR_CHECKRESERVEDLOCK
_IOERR_LOCK _ErrorCode = util.IOERR_LOCK
_IOERR_CLOSE _ErrorCode = util.IOERR_CLOSE
_IOERR_SEEK _ErrorCode = util.IOERR_SEEK
_IOERR_DELETE_NOENT _ErrorCode = util.IOERR_DELETE_NOENT
_CANTOPEN_FULLPATH _ErrorCode = util.CANTOPEN_FULLPATH
_OK_SYMLINK _ErrorCode = util.OK_SYMLINK
)
// OpenFlag is a flag for the [VFS] Open method.
//
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
)
// AccessFlag is a flag for the [VFS] Access method.
//
// https://www.sqlite.org/c3ref/c_access_exists.html
type AccessFlag uint32
const (
ACCESS_EXISTS AccessFlag = 0
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
ACCESS_READ AccessFlag = 2 /* Unused */
)
// SyncFlag is a flag for the [File] Sync method.
//
// https://www.sqlite.org/c3ref/c_sync_dataonly.html
type SyncFlag uint32
const (
SYNC_NORMAL SyncFlag = 0x00002
SYNC_FULL SyncFlag = 0x00003
SYNC_DATAONLY SyncFlag = 0x00010
)
// LockLevel is a value used with [File] Lock and Unlock methods.
//
// https://www.sqlite.org/c3ref/c_lock_exclusive.html
type LockLevel uint32
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
LOCK_NONE LockLevel = 0 /* xUnlock() only */
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
LOCK_SHARED LockLevel = 1 /* xLock() or xUnlock() */
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
LOCK_RESERVED LockLevel = 2 /* xLock() only */
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
LOCK_PENDING LockLevel = 3 /* internal use only */
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
LOCK_EXCLUSIVE LockLevel = 4 /* xLock() only */
)
// DeviceCharacteristic is a flag retuned by the [File] DeviceCharacteristics method.
//
// https://www.sqlite.org/c3ref/c_iocap_atomic.html
type DeviceCharacteristic uint32
const (
IOCAP_ATOMIC DeviceCharacteristic = 0x00000001
IOCAP_ATOMIC512 DeviceCharacteristic = 0x00000002
IOCAP_ATOMIC1K DeviceCharacteristic = 0x00000004
IOCAP_ATOMIC2K DeviceCharacteristic = 0x00000008
IOCAP_ATOMIC4K DeviceCharacteristic = 0x00000010
IOCAP_ATOMIC8K DeviceCharacteristic = 0x00000020
IOCAP_ATOMIC16K DeviceCharacteristic = 0x00000040
IOCAP_ATOMIC32K DeviceCharacteristic = 0x00000080
IOCAP_ATOMIC64K DeviceCharacteristic = 0x00000100
IOCAP_SAFE_APPEND DeviceCharacteristic = 0x00000200
IOCAP_SEQUENTIAL DeviceCharacteristic = 0x00000400
IOCAP_UNDELETABLE_WHEN_OPEN DeviceCharacteristic = 0x00000800
IOCAP_POWERSAFE_OVERWRITE DeviceCharacteristic = 0x00001000
IOCAP_IMMUTABLE DeviceCharacteristic = 0x00002000
IOCAP_BATCH_ATOMIC DeviceCharacteristic = 0x00004000
)
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type _FcntlOpcode uint32
const (
_FCNTL_LOCKSTATE _FcntlOpcode = 1
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
_FCNTL_SIZE_HINT _FcntlOpcode = 5
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
_FCNTL_FILE_POINTER _FcntlOpcode = 7
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
_FCNTL_OVERWRITE _FcntlOpcode = 11
_FCNTL_VFSNAME _FcntlOpcode = 12
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
_FCNTL_PRAGMA _FcntlOpcode = 14
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
_FCNTL_TRACE _FcntlOpcode = 19
_FCNTL_HAS_MOVED _FcntlOpcode = 20
_FCNTL_SYNC _FcntlOpcode = 21
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
_FCNTL_ZIPVFS _FcntlOpcode = 25
_FCNTL_RBU _FcntlOpcode = 26
_FCNTL_VFS_POINTER _FcntlOpcode = 27
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
_FCNTL_PDB _FcntlOpcode = 30
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
_FCNTL_DATA_VERSION _FcntlOpcode = 35
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
_FCNTL_CKPT_DONE _FcntlOpcode = 37
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
_FCNTL_CKPT_START _FcntlOpcode = 39
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
_FCNTL_CKSM_FILE _FcntlOpcode = 41
_FCNTL_RESET_CACHE _FcntlOpcode = 42
)

View File

@@ -0,0 +1,179 @@
package mptest
import (
"bytes"
"context"
"crypto/rand"
"embed"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
)
//go:embed testdata/mptest.wasm
var binary []byte
//go:embed testdata/*.*test
var scripts embed.FS
var (
rt wazero.Runtime
module wazero.CompiledModule
instances atomic.Uint64
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := sqlite3vfs.ExportHostFunctions(rt.NewHostModuleBuilder("env"))
env.NewFunctionBuilder().WithFunc(system).Export("system")
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func config(ctx context.Context) wazero.ModuleConfig {
name := strconv.FormatUint(instances.Add(1), 10)
log := ctx.Value(logger{}).(io.Writer)
fs, err := fs.Sub(scripts, "testdata")
if err != nil {
panic(err)
}
return wazero.NewModuleConfig().
WithName(name).WithStdout(log).WithStderr(log).WithFS(fs).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
}
func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr)
buf = buf[:bytes.IndexByte(buf, 0)]
args := strings.Split(string(buf), " ")
for i := range args {
args[i] = strings.Trim(args[i], `"`)
}
args = args[:len(args)-1]
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := sqlite3vfs.NewContext(ctx)
mod, _ := rt.InstantiateModule(ctx, module, cfg)
mod.Close(ctx)
vfs.Close()
}()
return 0
}
func Test_config01(t *testing.T) {
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_config02(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_crash01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func newContext(t *testing.T) context.Context {
return context.WithValue(context.Background(), logger{}, &testWriter{T: t})
}
type logger struct{}
type testWriter struct {
*testing.T
buf []byte
mtx sync.Mutex
}
func (l *testWriter) Write(p []byte) (n int, err error) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.buf = append(l.buf, p...)
for {
before, after, found := bytes.Cut(l.buf, []byte("\n"))
if !found {
return len(p), nil
}
l.Logf("%s", before)
l.buf = after
}
}

View File

@@ -0,0 +1,2 @@
mptest.wasm filter=lfs diff=lfs merge=lfs -text
*.*test -crlf

View File

@@ -0,0 +1 @@
mptest.c

28
sqlite3vfs/tests/mptest/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,28 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o mptest.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--stack-first \
-Wl,--import-undefined \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid
"$BINARYEN/wasm-opt" -g -O2 mptest.wasm -o mptest.tmp \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv mptest.tmp mptest.wasm

View File

@@ -0,0 +1,46 @@
/*
** Configure five tasks in different ways, then run tests.
*/
--if vfsname() GLOB 'unix'
PRAGMA page_size=8192;
--task 1
PRAGMA journal_mode=PERSIST;
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA journal_mode=TRUNCATE;
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA journal_mode=MEMORY;
--end
--task 4
PRAGMA journal_mode=OFF;
--end
--task 4
PRAGMA mmap_size(268435456);
--end
--source multiwrite01.test
--wait all
PRAGMA page_size=16384;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384

View File

@@ -0,0 +1,123 @@
/*
** Configure five tasks in different ways, then run tests.
*/
PRAGMA page_size=512;
--task 1
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA mmap_size=8192;
--end
--task 4
PRAGMA mmap_size=65536;
--end
--task 5
PRAGMA mmap_size=268435456;
--end
--source multiwrite01.test
--source crash02.subtest
PRAGMA page_size=1024;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 1024 1024 1024 1024 1024
PRAGMA page_size=2048;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 2048 2048 2048 2048 2048
PRAGMA page_size=8192;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 8192 8192 8192 8192 8192
PRAGMA page_size=16384;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384
PRAGMA auto_vacuum=FULL;
VACUUM;
--source multiwrite01.test
--source crash02.subtest
--wait all
PRAGMA auto_vacuum=FULL;
PRAGMA page_size=512;
VACUUM;
--source multiwrite01.test
--source crash02.subtest

View File

@@ -0,0 +1,106 @@
/* Test cases involving incomplete transactions that must be rolled back.
*/
--task 1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait 1
--task 2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
INSERT INTO t2 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
INSERT INTO t3 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
INSERT INTO t4 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
INSERT INTO t5 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
/* After the database file has been set up, run the crash2 subscript
** multiple times. */
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest

View File

@@ -0,0 +1,53 @@
/*
** This script is called from crash01.test and config02.test and perhaps other
** script. After the database file has been set up, make a big rollback
** journal in client 1, then crash client 1.
** Then in the other clients, do an integrity check.
*/
--task 1 leave-hot-journal
--sleep 5
--finish
PRAGMA cache_size=10;
BEGIN;
UPDATE t1 SET b=randomblob(20000);
UPDATE t2 SET b=randomblob(20000);
UPDATE t3 SET b=randomblob(20000);
UPDATE t4 SET b=randomblob(20000);
UPDATE t5 SET b=randomblob(20000);
UPDATE t1 SET b=NULL;
UPDATE t2 SET b=NULL;
UPDATE t3 SET b=NULL;
UPDATE t4 SET b=NULL;
UPDATE t5 SET b=NULL;
--print Task one crashing an incomplete transaction
--exit 1
--end
--task 2 integrity_check-2
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 3 integrity_check-3
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 4 integrity_check-4
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 5 integrity_check-5
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--wait all

18
sqlite3vfs/tests/mptest/testdata/main.c vendored Normal file
View File

@@ -0,0 +1,18 @@
#include <stdbool.h>
#include <stddef.h>
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS
#include "vfs.c"
__attribute__((constructor)) void init() { sqlite3_initialize(); }
static int dont_unlink(const char *pathname) { return 0; }
#define sqlite3_enable_load_extension(...)
#define sqlite3_trace(...)
#define unlink dont_unlink
#undef UNUSED_PARAMETER
#include "mptest.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aafb873c26d32dbf7498e4c4dd243eb6411c056432818121463d8cb49bc15f70
size 1479293

View File

@@ -0,0 +1,415 @@
/*
** This script sets up five different tasks all writing and updating
** the database at the same time, but each in its own table.
*/
--task 1 build-t1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 2 build-t2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t2 VALUES(1, randomblob(2000));
INSERT INTO t2 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t2 SELECT a+2, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+4, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+8, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+16, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+32, randomblob(1500) FROM t2;
SELECT count(*) FROM t2;
--match 64
SELECT avg(length(b)) FROM t2;
--match 1500.0
--sleep 2
UPDATE t2 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3 build-t3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t3 VALUES(1, randomblob(2000));
INSERT INTO t3 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t3 SELECT a+2, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+4, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+8, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+16, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+32, randomblob(1500) FROM t3;
SELECT count(*) FROM t3;
--match 64
SELECT avg(length(b)) FROM t3;
--match 1500.0
--sleep 2
UPDATE t3 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4 build-t4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t4 VALUES(1, randomblob(2000));
INSERT INTO t4 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t4 SELECT a+2, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+4, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+8, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+16, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+32, randomblob(1500) FROM t4;
SELECT count(*) FROM t4;
--match 64
SELECT avg(length(b)) FROM t4;
--match 1500.0
--sleep 2
UPDATE t4 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5 build-t5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t5 VALUES(1, randomblob(2000));
INSERT INTO t5 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t5 SELECT a+2, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+4, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+8, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+16, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+32, randomblob(1500) FROM t5;
SELECT count(*) FROM t5;
--match 64
SELECT avg(length(b)) FROM t5;
--match 1500.0
--sleep 2
UPDATE t5 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
SELECT count(*), sum(length(b)) FROM t1;
--match 64 247
SELECT count(*), sum(length(b)) FROM t2;
--match 64 247
SELECT count(*), sum(length(b)) FROM t3;
--match 64 247
SELECT count(*), sum(length(b)) FROM t4;
--match 64 247
SELECT count(*), sum(length(b)) FROM t5;
--match 64 247
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
--task 5
DROP INDEX t5b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t5b ON t5(b DESC);
--end
--task 3
DROP INDEX t3b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t3b ON t3(b DESC);
--end
--task 1
DROP INDEX t1b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t1b ON t1(b DESC);
--end
--task 2
DROP INDEX t2b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t2b ON t2(b DESC);
--end
--task 4
DROP INDEX t4b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t4b ON t4(b DESC);
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
VACUUM;
PRAGMA integrity_check(10);
--match ok
--task 1
UPDATE t1 SET b=randomblob(20000);
--sleep 5
UPDATE t1 SET b='x'||a||'y';
SELECT a FROM t1 WHERE b='x63y';
--match 63
--end
--task 2
UPDATE t2 SET b=randomblob(20000);
--sleep 5
UPDATE t2 SET b='x'||a||'y';
SELECT a FROM t2 WHERE b='x63y';
--match 63
--end
--task 3
UPDATE t3 SET b=randomblob(20000);
--sleep 5
UPDATE t3 SET b='x'||a||'y';
SELECT a FROM t3 WHERE b='x63y';
--match 63
--end
--task 4
UPDATE t4 SET b=randomblob(20000);
--sleep 5
UPDATE t4 SET b='x'||a||'y';
SELECT a FROM t4 WHERE b='x63y';
--match 63
--end
--task 5
UPDATE t5 SET b=randomblob(20000);
--sleep 5
UPDATE t5 SET b='x'||a||'y';
SELECT a FROM t5 WHERE b='x63y';
--match 63
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--wait all

View File

@@ -0,0 +1,85 @@
package speedtest1
import (
"bytes"
"context"
"crypto/rand"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
_ "embed"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
)
//go:embed testdata/speedtest1.wasm
var binary []byte
var (
rt wazero.Runtime
module wazero.CompiledModule
output bytes.Buffer
options []string
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := sqlite3vfs.ExportHostFunctions(rt.NewHostModuleBuilder("env"))
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func TestMain(m *testing.M) {
i := 1
options = append(options, "speedtest1")
for _, arg := range os.Args[1:] {
if strings.HasPrefix(arg, "-test.") {
os.Args[i] = arg
i++
} else {
options = append(options, arg)
}
}
os.Args = os.Args[:i]
code := m.Run()
io.Copy(os.Stderr, &output)
os.Exit(code)
}
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := sqlite3vfs.NewContext(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().
WithArgs(args...).WithName("speedtest1").
WithStdout(&output).WithStderr(&output).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
b.Error(err)
}
mod.Close(ctx)
vfs.Close()
}

View File

@@ -0,0 +1 @@
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1 @@
speedtest1.c

23
sqlite3vfs/tests/speedtest1/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o speedtest1.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--stack-first \
-Wl,--import-undefined
"$BINARYEN/wasm-opt" -g -O2 speedtest1.wasm -o speedtest1.tmp \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv speedtest1.tmp speedtest1.wasm

View File

@@ -0,0 +1,12 @@
#include <stdbool.h>
#include <stddef.h>
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS
#include "vfs.c"
#define randomFunc(args...) randomFunc2(args)
#include "speedtest1.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2638f5458f93b60f97006456f60e3e294cffb9ac22b20dba748fbd12e34f056d
size 1515841

366
sqlite3vfs/vfs.go Normal file
View File

@@ -0,0 +1,366 @@
package sqlite3vfs
import (
"context"
"crypto/rand"
"io"
"reflect"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// ExportHostFunctions registers the required VFS host functions
// with the provided env module.
//
// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly.
func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_vfs_find", vfsFind)
util.ExportFuncIIJ(env, "go_localtime", vfsLocaltime)
util.ExportFuncIIII(env, "go_randomness", vfsRandomness)
util.ExportFuncIII(env, "go_sleep", vfsSleep)
util.ExportFuncIII(env, "go_current_time", vfsCurrentTime)
util.ExportFuncIII(env, "go_current_time_64", vfsCurrentTime64)
util.ExportFuncIIIII(env, "go_full_pathname", vfsFullPathname)
util.ExportFuncIIII(env, "go_delete", vfsDelete)
util.ExportFuncIIIII(env, "go_access", vfsAccess)
util.ExportFuncIIIIII(env, "go_open", vfsOpen)
util.ExportFuncII(env, "go_close", vfsClose)
util.ExportFuncIIIIJ(env, "go_read", vfsRead)
util.ExportFuncIIIIJ(env, "go_write", vfsWrite)
util.ExportFuncIIJ(env, "go_truncate", vfsTruncate)
util.ExportFuncIII(env, "go_sync", vfsSync)
util.ExportFuncIII(env, "go_file_size", vfsFileSize)
util.ExportFuncIIII(env, "go_file_control", vfsFileControl)
util.ExportFuncII(env, "go_sector_size", vfsSectorSize)
util.ExportFuncII(env, "go_device_characteristics", vfsDeviceCharacteristics)
util.ExportFuncIII(env, "go_lock", vfsLock)
util.ExportFuncIII(env, "go_unlock", vfsUnlock)
util.ExportFuncIII(env, "go_check_reserved_lock", vfsCheckReservedLock)
return env
}
type vfsKey struct{}
type vfsState struct {
files []File
}
// NewContext creates a new context to hold [api.Module] specific VFS data.
//
// This context should be passed to any [api.Function] calls that might
// generate VFS host callbacks.
//
// The returned [io.Closer] should be closed after the [api.Module] is closed,
// to release any associated resources.
//
// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly.
func NewContext(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 {
name := util.ReadString(mod, zVfsName, _MAX_STRING)
if Find(name) != nil {
return 1
}
return 0
}
func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode {
tm := time.Unix(t, 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
const size = 32 / 8
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
util.WriteUint32(mod, pTm+0*size, uint32(tm.Second()))
util.WriteUint32(mod, pTm+1*size, uint32(tm.Minute()))
util.WriteUint32(mod, pTm+2*size, uint32(tm.Hour()))
util.WriteUint32(mod, pTm+3*size, uint32(tm.Day()))
util.WriteUint32(mod, pTm+4*size, uint32(tm.Month()-time.January))
util.WriteUint32(mod, pTm+5*size, uint32(tm.Year()-1900))
util.WriteUint32(mod, pTm+6*size, uint32(tm.Weekday()-time.Sunday))
util.WriteUint32(mod, pTm+7*size, uint32(tm.YearDay()-1))
util.WriteUint32(mod, pTm+8*size, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := util.View(mod, zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, mod api.Module, pVfs, nMicro uint32) _ErrorCode {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) _ErrorCode {
day := julianday.Float(time.Now())
util.WriteFloat64(mod, prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
util.WriteUint64(mod, piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) _ErrorCode {
vfs := vfsGet(mod, pVfs)
path := util.ReadString(mod, zRelative, _MAX_PATHNAME)
path, err := vfs.FullPathname(path)
size := uint64(len(path) + 1)
if size > uint64(nFull) {
return _CANTOPEN_FULLPATH
}
mem := util.View(mod, zFull, size)
mem[len(path)] = 0
copy(mem, path)
return vfsErrorCode(err, _CANTOPEN_FULLPATH)
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode {
vfs := vfsGet(mod, pVfs)
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
err := vfs.Delete(path, syncDir != 0)
return vfsErrorCode(err, _IOERR_DELETE)
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) _ErrorCode {
vfs := vfsGet(mod, pVfs)
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
ok, err := vfs.Access(path, flags)
var res uint32
if ok {
res = 1
}
util.WriteUint32(mod, pResOut, res)
return vfsErrorCode(err, _IOERR_ACCESS)
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, flags OpenFlag, pOutFlags uint32) _ErrorCode {
vfs := vfsGet(mod, pVfs)
var path string
if zPath != 0 {
path = util.ReadString(mod, zPath, _MAX_PATHNAME)
}
file, flags, err := vfs.Open(path, flags)
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
vfsFileRegister(ctx, mod, pFile, file)
if pOutFlags != 0 {
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode {
err := vfsFileClose(ctx, mod, pFile)
if err != nil {
return _IOERR_CLOSE
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
buf := util.View(mod, zBuf, uint64(iAmt))
n, err := file.ReadAt(buf, iOfst)
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return _IOERR_READ
}
for i := range buf[n:] {
buf[n+i] = 0
}
return _IOERR_SHORT_READ
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
buf := util.View(mod, zBuf, uint64(iAmt))
_, err := file.WriteAt(buf, iOfst)
if err != nil {
return _IOERR_WRITE
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
err := file.Truncate(nByte)
return vfsErrorCode(err, _IOERR_TRUNCATE)
}
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags SyncFlag) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
err := file.Sync(flags)
return vfsErrorCode(err, _IOERR_FSYNC)
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
size, err := file.FileSize()
util.WriteUint64(mod, pSize, uint64(size))
return vfsErrorCode(err, _IOERR_SEEK)
}
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
err := file.Lock(eLock)
return vfsErrorCode(err, _IOERR_LOCK)
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
err := file.Unlock(eLock)
return vfsErrorCode(err, _IOERR_UNLOCK)
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
locked, err := file.CheckReservedLock()
var res uint32
if locked {
res = 1
}
util.WriteUint32(mod, pResOut, res)
return vfsErrorCode(err, _IOERR_CHECKRESERVEDLOCK)
}
func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
switch op {
case _FCNTL_LOCKSTATE:
if file, ok := file.(FileLockState); ok {
util.WriteUint32(mod, pArg, uint32(file.LockState()))
return _OK
}
case _FCNTL_LOCK_TIMEOUT:
if file, ok := file.(*vfsFile); ok {
millis := file.lockTimeout.Milliseconds()
file.lockTimeout = time.Duration(util.ReadUint32(mod, pArg)) * time.Millisecond
util.WriteUint32(mod, pArg, uint32(millis))
return _OK
}
case _FCNTL_POWERSAFE_OVERWRITE:
if file, ok := file.(FilePowersafeOverwrite); ok {
switch util.ReadUint32(mod, pArg) {
case 0:
file.SetPowersafeOverwrite(false)
case 1:
file.SetPowersafeOverwrite(true)
default:
if file.PowersafeOverwrite() {
util.WriteUint32(mod, pArg, 1)
} else {
util.WriteUint32(mod, pArg, 0)
}
}
return _OK
}
case _FCNTL_SIZE_HINT:
if file, ok := file.(FileSizeHint); ok {
size := util.ReadUint64(mod, pArg)
err := file.SizeHint(int64(size))
return vfsErrorCode(err, _IOERR_TRUNCATE)
}
case _FCNTL_HAS_MOVED:
if file, ok := file.(FileHasMoved); ok {
moved, err := file.HasMoved()
var res uint32
if moved {
res = 1
}
util.WriteUint32(mod, pArg, res)
return vfsErrorCode(err, _IOERR_FSTAT)
}
}
// Consider also implementing these opcodes (in use by SQLite):
// _FCNTL_BUSYHANDLER
// _FCNTL_COMMIT_PHASETWO
// _FCNTL_PDB
// _FCNTL_PRAGMA
// _FCNTL_SYNC
return _NOTFOUND
}
func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 {
file := vfsFileGet(ctx, mod, pFile)
return uint32(file.SectorSize())
}
func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) DeviceCharacteristic {
file := vfsFileGet(ctx, mod, pFile)
return file.DeviceCharacteristics()
}
func vfsGet(mod api.Module, pVfs uint32) VFS {
if pVfs == 0 {
return vfsOS{}
}
const zNameOffset = 16
name := util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_STRING)
if name == "os" {
return vfsOS{}
}
if vfs := Find(name); vfs != nil {
return vfs
}
panic(util.NoVFSErr + util.ErrorString(name))
}
func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
switch v := reflect.ValueOf(err); v.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32:
return _ErrorCode(v.Uint())
}
return def
}

109
sqlite3vfs/vfs_api.go Normal file
View File

@@ -0,0 +1,109 @@
// Package sqlite3vfs wraps the C SQLite VFS API.
package sqlite3vfs
import "sync"
// A VFS defines the interface between the SQLite core and the underlying operating system.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes.
//
// https://www.sqlite.org/c3ref/vfs.html
type VFS interface {
Open(name string, flags OpenFlag) (File, OpenFlag, error)
Delete(name string, syncDir bool) error
Access(name string, flags AccessFlag) (bool, error)
FullPathname(name string) (string, error)
}
// A File represents an open file in the OS interface layer.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes.
// In particular, sqlite3.BUSY is necessary to correctly implement lock methods.
//
// https://www.sqlite.org/c3ref/io_methods.html
type File interface {
Close() error
ReadAt(p []byte, off int64) (n int, err error)
WriteAt(p []byte, off int64) (n int, err error)
Truncate(size int64) error
Sync(flags SyncFlag) error
FileSize() (int64, error)
Lock(lock LockLevel) error
Unlock(lock LockLevel) error
CheckReservedLock() (bool, error)
SectorSize() int
DeviceCharacteristics() DeviceCharacteristic
}
// FileLockState extends [File] to implement the
// SQLITE_FCNTL_LOCKSTATE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileLockState interface {
File
LockState() LockLevel
}
// FileLockState extends [File] to implement the
// SQLITE_FCNTL_SIZE_HINT file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileSizeHint interface {
File
SizeHint(size int64) error
}
// FileLockState extends [File] to implement the
// SQLITE_FCNTL_HAS_MOVED file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileHasMoved interface {
File
HasMoved() (bool, error)
}
// FileLockState extends [File] to implement the
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FilePowersafeOverwrite interface {
File
PowersafeOverwrite() bool
SetPowersafeOverwrite(bool)
}
var (
vfsRegistry map[string]VFS
vfsRegistryMtx sync.Mutex
)
// Find returns a VFS given its name.
// If there is no match, nil is returned.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Find(name string) VFS {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
return vfsRegistry[name]
}
// Register registers a VFS.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Register(name string, vfs VFS) {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
if vfsRegistry == nil {
vfsRegistry = map[string]VFS{}
}
vfsRegistry[name] = vfs
}
// Unregister unregisters a VFS.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Unregister(name string) {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
delete(vfsRegistry, name)
}

View File

@@ -0,0 +1,48 @@
package sqlite3vfs_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
)
type testVFS struct {
*testing.T
}
func (t testVFS) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, sqlite3vfs.OpenFlag, error) {
t.Log("Open", name, flags)
t.SkipNow()
return nil, flags, nil
}
func (t testVFS) Delete(name string, syncDir bool) error {
t.Log("Delete", name, syncDir)
return nil
}
func (t testVFS) Access(name string, flags sqlite3vfs.AccessFlag) (bool, error) {
t.Log("Access", name, flags)
return true, nil
}
func (t testVFS) FullPathname(name string) (string, error) {
t.Log("FullPathname", name)
return name, nil
}
func TestRegister(t *testing.T) {
vfs := testVFS{t}
sqlite3vfs.Register("foo", vfs)
defer sqlite3vfs.Unregister("foo")
conn, err := sqlite3.Open("file:file.db?vfs=foo")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
t.Error("want skip")
}

223
sqlite3vfs/vfs_file.go Normal file
View File

@@ -0,0 +1,223 @@
package sqlite3vfs
import (
"context"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
type vfsOS struct{}
func (vfsOS) FullPathname(path string) (string, error) {
path, err := filepath.Abs(path)
if err != nil {
return "", err
}
fi, err := os.Lstat(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return path, nil
}
return "", err
}
if fi.Mode()&fs.ModeSymlink != 0 {
err = _OK_SYMLINK
}
return path, err
}
func (vfsOS) Delete(path string, syncDir bool) error {
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
if err != nil {
return err
}
if runtime.GOOS != "windows" && syncDir {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return nil
}
func (vfsOS) Access(name string, flags AccessFlag) (bool, error) {
err := osAccess(name, flags)
if flags == ACCESS_EXISTS {
if errors.Is(err, fs.ErrNotExist) {
return false, nil
}
} else {
if errors.Is(err, fs.ErrPermission) {
return false, nil
}
}
return err == nil, err
}
func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var f *os.File
if name == "" {
f, err = os.CreateTemp("", "*.db")
} else {
f, err = osOpenFile(name, oflags, 0666)
}
if err != nil {
return nil, flags, err
}
if flags&OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
file := vfsFile{
File: f,
psow: true,
readOnly: flags&OPEN_READONLY != 0,
syncDir: runtime.GOOS != "windows" &&
flags&(OPEN_CREATE) != 0 &&
flags&(OPEN_MAIN_JOURNAL|OPEN_SUPER_JOURNAL|OPEN_WAL) != 0,
}
return &file, flags, nil
}
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock LockLevel
psow bool
syncDir bool
readOnly bool
}
var (
// Ensure these interfaces are implemented:
_ FileLockState = &vfsFile{}
_ FileHasMoved = &vfsFile{}
_ FileSizeHint = &vfsFile{}
_ FilePowersafeOverwrite = &vfsFile{}
)
func vfsFileNew(vfs *vfsState, file File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) {
const fileHandleOffset = 4
id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file)
util.WriteUint32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
return vfs.files[id]
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
}
func (f *vfsFile) Sync(flags SyncFlag) error {
dataonly := (flags & SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == SYNC_FULL
err := osSync(f.File, fullsync, dataonly)
if err != nil {
return err
}
if runtime.GOOS != "windows" && f.syncDir {
f.syncDir = false
d, err := os.Open(filepath.Dir(f.File.Name()))
if err != nil {
return nil
}
defer d.Close()
err = osSync(d, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return nil
}
func (f *vfsFile) FileSize() (int64, error) {
return f.Seek(0, io.SeekEnd)
}
func (*vfsFile) SectorSize() int {
return _DEFAULT_SECTOR_SIZE
}
func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
if f.psow {
return IOCAP_POWERSAFE_OVERWRITE
}
return 0
}
func (f *vfsFile) SizeHint(size int64) error {
return osAllocate(f.File, size)
}
func (f *vfsFile) HasMoved() (bool, error) {
fi, err := f.Stat()
if err != nil {
return false, err
}
pi, err := os.Stat(f.Name())
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return false, err
}
return !os.SameFile(fi, pi), nil
}
func (f *vfsFile) LockState() LockLevel { return f.lock }
func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }

150
sqlite3vfs/vfs_lock.go Normal file
View File

@@ -0,0 +1,150 @@
package sqlite3vfs
import (
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func (file *vfsFile) Lock(eLock LockLevel) error {
// Argument check. SQLite never explicitly requests a pending lock.
if eLock != LOCK_SHARED && eLock != LOCK_RESERVED && eLock != LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
switch {
case file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
case file.lock == LOCK_NONE && eLock > LOCK_SHARED:
// We never move from unlocked to anything higher than a shared lock.
panic(util.AssertErr())
case file.lock != LOCK_SHARED && eLock == LOCK_RESERVED:
// A shared lock is always held when a reserved lock is requested.
panic(util.AssertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if file.lock >= eLock {
return nil
}
// Do not allow any kind of write-lock on a read-only database.
if file.readOnly && eLock >= LOCK_RESERVED {
return _IOERR_LOCK
}
switch eLock {
case LOCK_SHARED:
// Must be unlocked to get SHARED.
if file.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = LOCK_SHARED
return nil
case LOCK_RESERVED:
// Must be SHARED to get RESERVED.
if file.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = LOCK_RESERVED
return nil
case LOCK_EXCLUSIVE:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if file.lock <= LOCK_NONE || file.lock >= LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if file.lock < LOCK_PENDING {
if rc := osGetPendingLock(file.File); rc != _OK {
return rc
}
file.lock = LOCK_PENDING
}
if rc := osGetExclusiveLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = LOCK_EXCLUSIVE
return nil
default:
panic(util.AssertErr())
}
}
func (file *vfsFile) Unlock(eLock LockLevel) error {
// Argument check.
if eLock != LOCK_NONE && eLock != LOCK_SHARED {
panic(util.AssertErr())
}
// Connection state check.
if file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// If we don't have a more restrictive lock, do nothing.
if file.lock <= eLock {
return nil
}
switch eLock {
case LOCK_SHARED:
if rc := osDowngradeLock(file.File, file.lock); rc != _OK {
return rc
}
file.lock = LOCK_SHARED
return nil
case LOCK_NONE:
rc := osReleaseLock(file.File, file.lock)
file.lock = LOCK_NONE
return rc
default:
panic(util.AssertErr())
}
}
func (file *vfsFile) CheckReservedLock() (bool, error) {
// Connection state check.
if file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
if file.lock >= LOCK_RESERVED {
return true, nil
}
return osCheckReservedLock(file.File)
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

219
sqlite3vfs/vfs_lock_test.go Normal file
View File

@@ -0,0 +1,219 @@
package sqlite3vfs
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mod := util.NewMockModule(128)
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
rc := vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_RESERVED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_RESERVED) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_EXCLUSIVE)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_EXCLUSIVE) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) {
t.Error("invalid lock state", got)
}
rc = vfsUnlock(ctx, mod, pFile2, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}

56
sqlite3vfs/vfs_os_bsd.go Normal file
View File

@@ -0,0 +1,56 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
package sqlite3vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
if start == 0 && len == 0 {
err := unix.Flock(int(file.Fd()), unix.LOCK_UN)
if err != nil {
return _IOERR_UNLOCK
}
}
return _OK
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_SH|unix.LOCK_NB, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_EX|unix.LOCK_NB, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

106
sqlite3vfs/vfs_os_darwin.go Normal file
View File

@@ -0,0 +1,106 @@
//go:build !sqlite3_bsd
package sqlite3vfs
import (
"io"
"os"
"time"
"golang.org/x/sys/unix"
)
const (
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
_F_OFD_SETLK = 90
_F_OFD_SETLKW = 91
_F_OFD_GETLK = 92
_F_OFD_SETLKWTIMEOUT = 93
)
type flocktimeout_t struct {
fl unix.Flock_t
timeout unix.Timespec
}
func osSync(file *os.File, fullsync, dataonly bool) error {
if fullsync {
return file.Sync()
}
return unix.Fsync(int(file.Fd()))
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
// https://stackoverflow.com/a/11497568/867786
store := unix.Fstore_t{
Flags: unix.F_ALLOCATECONTIG,
Posmode: unix.F_PEOFPOSMODE,
Offset: 0,
Length: size,
}
// Try to get a continuous chunk of disk space.
err = unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
if err != nil {
// OK, perhaps we are too fragmented, allocate non-continuous.
store.Flags = unix.F_ALLOCATEALL
unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
}
return file.Truncate(size)
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := flocktimeout_t{fl: unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}}
var err error
if timeout == 0 {
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &lock.fl)
} else {
lock.timeout = unix.NsecToTimespec(int64(timeout / time.Nanosecond))
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLKWTIMEOUT, &lock.fl)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), _F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,25 @@
package sqlite3vfs
import (
"os"
"golang.org/x/sys/unix"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
if dataonly {
_, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0)
if err != 0 {
return err
}
return nil
}
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
return unix.Fallocate(int(file.Fd()), 0, 0, size)
}

63
sqlite3vfs/vfs_os_ofd.go Normal file
View File

@@ -0,0 +1,63 @@
//go:build linux || illumos
package sqlite3vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,23 @@
//go:build !linux && (!darwin || sqlite3_bsd)
package sqlite3vfs
import (
"io"
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
return file.Truncate(size)
}

87
sqlite3vfs/vfs_os_unix.go Normal file
View File

@@ -0,0 +1,87 @@
//go:build unix
package sqlite3vfs
import (
"io/fs"
"os"
"time"
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
case ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case ACCESS_READ:
access = unix.R_OK
}
return unix.Access(path, access)
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

View File

@@ -0,0 +1,240 @@
package sqlite3vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/windows"
)
// osOpenFile is a simplified copy of [os.openFileNolog]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/os/file_windows.go
//
// See: https://go.dev/issue/32088
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT}
}
r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
if rc == _OK {
// Acquire the SHARED lock.
rc = osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
// Release the PENDING lock.
osUnlock(file, _PENDING_BYTE, 1)
}
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
if rc != _OK {
// Reacquire the SHARED lock.
osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
return rc
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
// Release all locks.
if state >= LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= LOCK_SHARED {
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err == windows.ERROR_NOT_LOCKED {
return _OK
}
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
0, len, 0, &windows.Overlapped{Offset: start})
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len uint32) (bool, _ErrorCode) {
rc := osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, 0, _IOERR_CHECKRESERVEDLOCK)
if rc == _BUSY {
return true, _OK
}
if rc == _OK {
osUnlock(file, start, len)
}
return false, rc
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(windows.Errno); ok {
// https://devblogs.microsoft.com/oldnewthing/20140905-00/?p=63
switch errno {
case
windows.ERROR_LOCK_VIOLATION,
windows.ERROR_IO_PENDING,
windows.ERROR_OPERATION_ABORTED:
return _BUSY
}
}
return def
}
// syscallOpen is a simplified copy of [syscall.Open]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
var access uint32
switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) {
case syscall.O_RDONLY:
access = syscall.GENERIC_READ
case syscall.O_WRONLY:
access = syscall.GENERIC_WRITE
case syscall.O_RDWR:
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
}
if mode&syscall.O_CREAT != 0 {
access |= syscall.GENERIC_WRITE
}
if mode&syscall.O_APPEND != 0 {
access &^= syscall.GENERIC_WRITE
access |= syscall.FILE_APPEND_DATA
}
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var createmode uint32
switch {
case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL):
createmode = syscall.CREATE_NEW
case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC):
createmode = syscall.CREATE_ALWAYS
case mode&syscall.O_CREAT == syscall.O_CREAT:
createmode = syscall.OPEN_ALWAYS
case mode&syscall.O_TRUNC == syscall.O_TRUNC:
createmode = syscall.TRUNCATE_EXISTING
default:
createmode = syscall.OPEN_EXISTING
}
var attrs uint32 = syscall.FILE_ATTRIBUTE_NORMAL
if perm&syscall.S_IWRITE == 0 {
attrs = syscall.FILE_ATTRIBUTE_READONLY
}
if createmode == syscall.OPEN_EXISTING && access == syscall.GENERIC_READ {
// Necessary for opening directory handles.
attrs |= syscall.FILE_FLAG_BACKUP_SEMANTICS
}
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, attrs, 0)
}

344
sqlite3vfs/vfs_test.go Normal file
View File

@@ -0,0 +1,344 @@
package sqlite3vfs
import (
"bytes"
"context"
"errors"
"io/fs"
"math"
"os"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
func Test_vfsLocaltime(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
tm := time.Now()
rc := vfsLocaltime(ctx, mod, 4, tm.Unix())
if rc != 0 {
t.Fatal("returned", rc)
}
if s := util.ReadUint32(mod, 4+0*4); int(s) != tm.Second() {
t.Error("wrong second")
}
if m := util.ReadUint32(mod, 4+1*4); int(m) != tm.Minute() {
t.Error("wrong minute")
}
if h := util.ReadUint32(mod, 4+2*4); int(h) != tm.Hour() {
t.Error("wrong hour")
}
if d := util.ReadUint32(mod, 4+3*4); int(d) != tm.Day() {
t.Error("wrong day")
}
if m := util.ReadUint32(mod, 4+4*4); time.Month(1+m) != tm.Month() {
t.Error("wrong month")
}
if y := util.ReadUint32(mod, 4+5*4); 1900+int(y) != tm.Year() {
t.Error("wrong year")
}
if w := util.ReadUint32(mod, 4+6*4); time.Weekday(w) != tm.Weekday() {
t.Error("wrong weekday")
}
if d := util.ReadUint32(mod, 4+7*4); int(d) != tm.YearDay()-1 {
t.Error("wrong yearday")
}
}
func Test_vfsRandomness(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
rc := vfsRandomness(ctx, mod, 0, 16, 4)
if rc != 16 {
t.Fatal("returned", rc)
}
var zero [16]byte
if got := util.View(mod, 4, 16); bytes.Equal(got, zero[:]) {
t.Fatal("all zero")
}
}
func Test_vfsSleep(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
rc := vfsSleep(ctx, mod, 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
want := 123456 * time.Microsecond
if got := time.Since(now); got < want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
rc := vfsCurrentTime(ctx, mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
want := julianday.Float(now)
if got := util.ReadFloat64(mod, 4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime64(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(ctx, mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000
if got := util.ReadUint64(mod, 4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsFullPathname(t *testing.T) {
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 4, ".")
ctx := context.TODO()
rc := vfsFullPathname(ctx, mod, 0, 4, 0, 8)
if rc != _CANTOPEN_FULLPATH {
t.Errorf("returned %d, want %d", rc, _CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(ctx, mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
want, _ := filepath.Abs(".")
if got := util.ReadString(mod, 8, _MAX_PATHNAME); got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsDelete(t *testing.T) {
name := filepath.Join(t.TempDir(), "test.db")
file, err := os.Create(name)
if err != nil {
t.Fatal(err)
}
file.Close()
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 4, name)
ctx := context.TODO()
rc := vfsDelete(ctx, mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(ctx, mod, 0, 4, 1)
if rc != _IOERR_DELETE_NOENT {
t.Fatal("returned", rc)
}
}
func Test_vfsAccess(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
f.Close()
}
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
t.Fatal(err)
}
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 8, dir)
ctx := context.TODO()
rc := vfsAccess(ctx, mod, 0, 8, ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("directory did not exist")
}
rc = vfsAccess(ctx, mod, 0, 8, ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("can't access directory")
}
util.WriteString(mod, 8, file)
rc = vfsAccess(ctx, mod, 0, 8, ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {
mod := util.NewMockModule(128)
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check sector size.
if size := vfsSectorSize(ctx, mod, 4); size != _DEFAULT_SECTOR_SIZE {
t.Fatal("returned", size)
}
// Write stuff.
text := "Hello world!"
util.WriteString(mod, 16, text)
rc = vfsWrite(ctx, mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(ctx, mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got != uint32(len(text)) {
t.Errorf("got %d", got)
}
// Partial read at offset.
rc = vfsRead(ctx, mod, 4, 16, uint32(len(text)), 4)
if rc != _IOERR_SHORT_READ {
t.Fatal("returned", rc)
}
if got := util.ReadString(mod, 16, 64); got != text[4:] {
t.Errorf("got %q", got)
}
// Truncate the file.
rc = vfsTruncate(ctx, mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(ctx, mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got != 4 {
t.Errorf("got %d", got)
}
// Read at offset.
rc = vfsRead(ctx, mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadString(mod, 32, 64); got != text[:4] {
t.Errorf("got %q", got)
}
// Close the file.
rc = vfsClose(ctx, mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
}
func Test_vfsFile_psow(t *testing.T) {
mod := util.NewMockModule(128)
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Read powersafe overwrite.
util.WriteUint32(mod, 16, math.MaxUint32)
rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got == 0 {
t.Error("psow disabled")
}
// Unset powersafe overwrite.
util.WriteUint32(mod, 16, 0)
rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
// Read powersafe overwrite.
util.WriteUint32(mod, 16, math.MaxUint32)
rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got != 0 {
t.Error("psow enabled")
}
// Set powersafe overwrite.
util.WriteUint32(mod, 16, 1)
rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
// Read powersafe overwrite.
util.WriteUint32(mod, 16, math.MaxUint32)
rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 16); got == 0 {
t.Error("psow disabled")
}
// Close the file.
rc = vfsClose(ctx, mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
}

93
stmt.go
View File

@@ -3,6 +3,8 @@ package sqlite3
import (
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Stmt is a prepared statement object.
@@ -10,8 +12,8 @@ import (
// https://www.sqlite.org/c3ref/stmt.html
type Stmt struct {
c *Conn
handle uint32
err error
handle uint32
}
// Close destroys the prepared statement object.
@@ -119,7 +121,7 @@ func (s *Stmt) BindName(param int) string {
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// BindBool binds a bool to the prepared statement.
@@ -172,7 +174,7 @@ func (s *Stmt) BindText(param int, value string) error {
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
}
@@ -186,7 +188,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
uint64(s.c.api.destructor))
return s.c.error(r[0])
}
@@ -215,6 +217,9 @@ func (s *Stmt) BindNull(param int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
if format == TimeFormatDefault {
return s.bindRFC3339Nano(param, value)
}
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
@@ -223,11 +228,25 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
case float64:
s.BindFloat(param, v)
default:
panic(assertErr())
panic(util.AssertErr())
}
return nil
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -247,9 +266,9 @@ func (s *Stmt) ColumnName(col int) string {
ptr := uint32(r[0])
if ptr == 0 {
panic(oomErr)
panic(util.OOMErr)
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// ColumnType returns the initial [Datatype] of the result column.
@@ -320,7 +339,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
case NULL:
return time.Time{}
default:
panic(assertErr())
panic(util.AssertErr())
}
t, err := format.Decode(v)
if err != nil {
@@ -334,21 +353,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.handle))
s.err = s.c.error(r[0])
return ""
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
mem := s.c.mem.view(ptr, r[0])
return string(mem)
return string(s.ColumnRawText(col))
}
// ColumnBlob appends to buf and returns
@@ -357,21 +362,53 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
r := s.c.call(s.c.api.columnBlob,
return append(buf, s.ColumnRawBlob(col)...)
}
// ColumnRawText returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.handle))
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return buf[0:0]
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
mem := s.c.mem.view(ptr, r[0])
return append(buf[0:0], mem...)
return util.View(s.c.mod, ptr, r[0])
}
// ColumnRawBlob returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r[0])
}
// Return true if stmt is an empty SQL statement.

127
tests/backup_test.go Normal file
View File

@@ -0,0 +1,127 @@
package tests
import (
"path/filepath"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBackup(t *testing.T) {
t.Parallel()
backupName := filepath.Join(t.TempDir(), "backup.db")
func() { // Create backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
err = db.Backup("main", backupName)
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Restore backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Restore("main", backupName)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Errors.
db, err := sqlite3.Open(backupName)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"hash/adler32"
"io"
"testing"
@@ -48,17 +49,17 @@ func TestBlob(t *testing.T) {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:size/2]))
_, err = blob.Write(data[:size/2])
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:]))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
n, err := blob.Write(data[:])
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
_, err = io.Copy(blob, bytes.NewReader(data[size/2:size]))
_, err = blob.Write(data[size/2 : size])
if err != nil {
t.Fatal(err)
}
@@ -87,6 +88,126 @@ func TestBlob(t *testing.T) {
}
}
func TestBlob_large(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1000000))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
size := blob.Size()
if size != 1000000 {
t.Errorf("got %d, want 1000000", size)
}
hash := adler32.New()
_, err = io.CopyN(blob, io.TeeReader(rand.Reader, hash), 1000000)
if err != nil {
t.Fatal(err)
}
_, err = blob.Seek(0, io.SeekStart)
if err != nil {
t.Fatal(err)
}
want := hash.Sum32()
hash.Reset()
_, err = io.Copy(hash, blob)
if err != nil {
t.Fatal(err)
}
if got := hash.Sum32(); got != want {
t.Fatalf("got %d, want %d", got, want)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_overflow(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
n, err := blob.ReadFrom(rand.Reader)
if n != 1024 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
n, err = blob.ReadFrom(rand.Reader)
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
_, err = blob.Seek(-128, io.SeekEnd)
if err != nil {
t.Fatal(err)
}
n, err = blob.WriteTo(io.Discard)
if n != 128 || err != nil {
t.Fatalf("got (%d, %v), want (128, nil)", n, err)
}
n, err = blob.WriteTo(io.Discard)
if n != 0 || err != nil {
t.Fatalf("got (%d, %v), want (0, nil)", n, err)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_invalid(t *testing.T) {
t.Parallel()

View File

@@ -126,7 +126,7 @@ func testTxQuery(t params) {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
t.Fatal("expected one row")
}
var name string

View File

@@ -3,6 +3,7 @@ package tests
import (
"context"
"errors"
"reflect"
"strings"
"testing"
@@ -13,7 +14,7 @@ import (
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
_, err := sqlite3.OpenFlags(".", 0)
if err == nil {
t.Fatal("want error")
}
@@ -58,6 +59,41 @@ func TestConn_Close_BUSY(t *testing.T) {
}
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open("file::memory:?_pragma=busy_timeout(1000)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
got, err := db.Pragma("busy_timeout")
if err != nil {
t.Fatal(err)
}
want := []string{"1000"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
var serr *sqlite3.Error
_, err = db.Pragma("+")
if err == nil {
t.Error("want: error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: near "+": syntax error` {
t.Error("got message:", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
t.Parallel()
@@ -202,59 +238,3 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Error("got message:", got)
}
}
func TestConn_MustPrepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(``)
t.Error("want panic")
}
func TestConn_MustPrepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT 1; -- HERE`)
t.Error("want panic")
}
func TestConn_MustPrepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.Pragma("encoding=''")
t.Error("want panic")
}

77
tests/ext_test.go Normal file
View File

@@ -0,0 +1,77 @@
package tests
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func Test_base64(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// base64
stmt, _, err := db.Prepare(`SELECT base64('TWFueSBoYW5kcyBtYWtlIGxpZ2h0IHdvcmsu')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if got := stmt.ColumnText(0); got != "Many hands make light work." {
t.Errorf("got %q", got)
}
}
func Test_decimal(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT decimal_add(decimal('0.1'), decimal('0.2')) = decimal('0.3')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}
func Test_uint(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT 'z2' < 'z11' COLLATE UINT`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}

View File

@@ -14,8 +14,15 @@ import (
)
func TestParallel(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
testParallel(t, name, 1000)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -135,7 +142,7 @@ func testParallel(t *testing.T, name string, n int) {
}
var group errgroup.Group
group.SetLimit(4)
group.SetLimit(6)
for i := 0; i < n; i++ {
if i&7 != 7 {
group.Go(reader)

169
tests/time_test.go Normal file
View File

@@ -0,0 +1,169 @@
package tests
import (
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
time time.Time
want any
}{
{sqlite3.TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{sqlite3.TimeFormatJulianDay, reference, 2456572.849526851851852},
{sqlite3.TimeFormatUnix, reference, int64(1381134199)},
{sqlite3.TimeFormatUnixFrac, reference, 1381134199.120},
{sqlite3.TimeFormatUnixMilli, reference, int64(1381134199_120)},
{sqlite3.TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{sqlite3.TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{sqlite3.TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
stmt.Step()
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
if err != nil {
t.Fatal(err)
}
var t0 time.Time
for stmt.Step() {
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
if t0.After(t1) {
t.Errorf("got %v after %v", t0, t1)
}
t0 = t1
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -29,6 +29,7 @@ func TestConn_Transaction_exec(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
@@ -117,6 +118,7 @@ func TestConn_Transaction_panic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
@@ -183,10 +185,10 @@ func TestConn_Transaction_interrupt(t *testing.T) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
var nilErr error
tx.End(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -208,6 +210,33 @@ func TestConn_Transaction_interrupt(t *testing.T) {
}
}
func TestConn_Transaction_interrupted(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
cancel()
tx := db.Begin()
err = tx.Commit()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
@@ -275,6 +304,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
@@ -283,7 +313,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
}
insert := func(succeed bool) (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -341,7 +371,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
}
panics := func() (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -361,6 +391,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
@@ -391,12 +422,12 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
@@ -404,19 +435,19 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
release1 := db.Savepoint()
savept1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
release2 := db.Savepoint()
savept2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint()(&err)
db.Savepoint().Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
@@ -427,15 +458,15 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
}
err = context.Canceled
release2(&err)
savept2.Release(&err)
if err != context.Canceled {
t.Fatal(err)
}
var nilErr error
release1(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
savept1.Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -471,7 +502,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
@@ -480,7 +511,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,19 +1,23 @@
package sqlite3
package tests
import "testing"
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
data sqlite3.Datatype
want string
}{
{INTEGER, "INTEGER"},
{FLOAT, "FLOAT"},
{TEXT, "TEXT"},
{BLOB, "BLOB"},
{NULL, "NULL"},
{sqlite3.INTEGER, "INTEGER"},
{sqlite3.FLOAT, "FLOAT"},
{sqlite3.TEXT, "TEXT"},
{sqlite3.BLOB, "BLOB"},
{sqlite3.NULL, "NULL"},
{10, "10"},
}
for _, tt := range tests {

40
time.go
View File

@@ -6,6 +6,7 @@ import (
"strings"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
@@ -62,13 +63,18 @@ const (
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving any timezone offset.
//
// This is the format used by the database/sql driver:
// [database/sql.Row.Scan] is able to decode as [time.Time]
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
// Use [TimeFormat7] for time-ordered encoding.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
@@ -78,6 +84,8 @@ const (
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
@@ -123,9 +131,9 @@ func (f TimeFormat) Encode(t time.Time) any {
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
// are safe to use for current events, from 1980 through at least 2260.
// Unix timestamps before 1980 may be misinterpreted as julian day numbers,
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html
@@ -141,7 +149,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
@@ -160,7 +168,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMilli:
@@ -177,7 +185,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMilli(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMicro:
@@ -194,14 +202,14 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMicro(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
v = i
}
@@ -211,7 +219,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(0, int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
// Special formats
@@ -281,7 +289,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case
@@ -293,7 +301,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
return f.parseRelaxed(s)
@@ -303,7 +311,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
@@ -311,7 +319,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
default:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
if f == "" {
f = time.RFC3339Nano

View File

@@ -1,118 +0,0 @@
package sqlite3
import (
"reflect"
"testing"
"time"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
time time.Time
want any
}{
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{TimeFormatJulianDay, reference, 2456572.849526851851852},
{TimeFormatUnix, reference, int64(1381134199)},
{TimeFormatUnixFrac, reference, 1381134199.120},
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{TimeFormatJulianDay, false, time.Time{}, 0, true},
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{TimeFormatUnix, "abc", time.Time{}, 0, true},
{TimeFormatUnix, false, time.Time{}, 0, true},
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{TimeFormatUnixNano, false, time.Time{}, 0, true},
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199", reference, time.Second, false},
{TimeFormatAuto, "1381134199120", reference, 0, false},
{TimeFormatAuto, "1381134199120000", reference, 0, false},
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormatAuto, "abc", time.Time{}, 0, true},
{TimeFormatAuto, false, time.Time{}, 0, true},
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormat9, "08:23:19.12", reftime, 0, false},
{TimeFormat3, false, time.Time{}, 0, true},
{TimeFormat9, false, time.Time{}, 0, true},
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}

155
tx.go
View File

@@ -4,9 +4,14 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
)
// Tx is an in-progress database transaction.
//
// https://www.sqlite.org/lang_transaction.html
type Tx struct {
c *Conn
}
@@ -15,8 +20,9 @@ type Tx struct {
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
err := c.Exec(`BEGIN DEFERRED`)
if err != nil && !errors.Is(err, INTERRUPT) {
// BEGIN even if interrupted.
err := c.txExecInterrupted(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Tx{c}
@@ -63,21 +69,22 @@ func (tx Tx) End(errp *error) {
defer panic(recovered)
}
if tx.c.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
@@ -91,33 +98,28 @@ func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rollsback the transaction.
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
// ROLLBACK even if the connection has been interrupted.
old := tx.c.SetInterrupt(context.Background())
defer tx.c.SetInterrupt(old)
return tx.c.Exec(`ROLLBACK`)
return tx.c.txExecInterrupted(`ROLLBACK`)
}
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a release func that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// // ... do work in the transaction
// }
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
@@ -126,52 +128,75 @@ func (c *Conn) Savepoint() (release func(*error)) {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
if errors.Is(err, INTERRUPT) {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return Savepoint{c: c, name: name}
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// savept := conn.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if c.GetAutocommit() {
// There is nothing to commit/rollback.
if *errp == nil && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txExecInterrupted(fmt.Sprintf(`
ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://www.sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
}
func (c *Conn) txExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
}

318
vfs.go
View File

@@ -1,318 +0,0 @@
package sqlite3
import (
"context"
"crypto/rand"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/sys"
)
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
_, err := wasi.Instantiate(ctx)
if err != nil {
panic(err)
}
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("os_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("os_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("os_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("os_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("os_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("os_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("os_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("os_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("os_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("os_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("os_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("os_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("os_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("os_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("os_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("os_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("os_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("os_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("os_file_control")
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
type vfsOSMethods bool
const vfsOS vfsOSMethods = false
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
_ = mod.CloseWithExitCode(ctx, exitCode)
// Prevent any code from executing after this function.
panic(sys.NewExitError(mod.Name(), exitCode))
}
func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uint32 {
tm := time.Unix(int64(t), 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
mem := memory{mod}
mem.writeUint32(pTm+0*ptrlen, uint32(tm.Second()))
mem.writeUint32(pTm+1*ptrlen, uint32(tm.Minute()))
mem.writeUint32(pTm+2*ptrlen, uint32(tm.Hour()))
mem.writeUint32(pTm+3*ptrlen, uint32(tm.Day()))
mem.writeUint32(pTm+4*ptrlen, uint32(tm.Month()-time.January))
mem.writeUint32(pTm+5*ptrlen, uint32(tm.Year()-1900))
mem.writeUint32(pTm+6*ptrlen, uint32(tm.Weekday()-time.Sunday))
mem.writeUint32(pTm+7*ptrlen, uint32(tm.YearDay()-1))
mem.writeUint32(pTm+8*ptrlen, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.view(zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, pVfs, nMicro uint32) uint32 {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uint32 {
day := julianday.Float(time.Now())
memory{mod}.writeFloat64(prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) uint32 {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
memory{mod}.writeUint64(piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) uint32 {
rel := memory{mod}.readString(zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return uint32(IOERR)
}
// Consider either using [filepath.EvalSymlinks] to canonicalize the path (as the Unix VFS does).
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
return _OK
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) uint32 {
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _OK
}
if err != nil {
return uint32(IOERR_DELETE)
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err == nil {
err = f.Sync()
f.Close()
}
if err != nil {
return uint32(IOERR_DELETE)
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
fi, err := os.Stat(path)
var res uint32
switch {
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
return uint32(IOERR_ACCESS)
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
res = 1
} else {
res = 0
}
case errors.Is(err, fs.ErrPermission):
res = 0
default:
return uint32(IOERR_ACCESS)
}
memory{mod}.writeUint32(pResOut, res)
return _OK
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var file *os.File
if zName == 0 {
file, err = os.CreateTemp("", "*.db")
} else {
name := memory{mod}.readString(zName, _MAX_PATHNAME)
file, err = os.OpenFile(name, oflags, 0600)
}
if err != nil {
return uint32(CANTOPEN)
}
if flags&OPEN_DELETEONCLOSE != 0 {
vfsOS.DeleteOnClose(file)
}
id := vfsGetFileID(file)
vfsFilePtr{mod, pFile}.SetID(id).SetLock(_NO_LOCK)
if pOutFlags != 0 {
memory{mod}.writeUint32(pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
id := vfsFilePtr{mod, pFile}.ID()
err := vfsCloseFile(id)
if err != nil {
return uint32(IOERR_CLOSE)
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
n, err := file.ReadAt(buf, int64(iOfst))
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return uint32(IOERR_READ)
}
for i := range buf[n:] {
buf[n+i] = 0
}
return uint32(IOERR_SHORT_READ)
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
_, err := file.WriteAt(buf, int64(iOfst))
if err != nil {
return uint32(IOERR_WRITE)
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
err := file.Truncate(int64(nByte))
if err != nil {
return uint32(IOERR_TRUNCATE)
}
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
err := file.Sync()
if err != nil {
return uint32(IOERR_FSYNC)
}
return _OK
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 {
// This uses [os.File.Seek] because we don't care about the offset for reading/writing.
// But consider using [os.File.Stat] instead (as other VFSes do).
file := vfsFilePtr{mod, pFile}.OSFile()
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return uint32(IOERR_SEEK)
}
memory{mod}.writeUint64(pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
return uint32(NOTFOUND)
}

View File

@@ -1,69 +0,0 @@
package sqlite3
import (
"os"
"sync"
"github.com/tetratelabs/wazero/api"
)
var (
vfsOpenFiles []*os.File
vfsOpenFilesMtx sync.Mutex
)
func vfsGetFileID(file *os.File) uint32 {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
// Find an empty slot.
for id, ptr := range vfsOpenFiles {
if ptr == nil {
vfsOpenFiles[id] = file
return uint32(id)
}
}
// Add a new slot.
vfsOpenFiles = append(vfsOpenFiles, file)
return uint32(len(vfsOpenFiles) - 1)
}
func vfsCloseFile(id uint32) error {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
file := vfsOpenFiles[id]
vfsOpenFiles[id] = nil
return file.Close()
}
type vfsFilePtr struct {
api.Module
ptr uint32
}
func (p vfsFilePtr) OSFile() *os.File {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return vfsOpenFiles[id]
}
func (p vfsFilePtr) ID() uint32 {
return memory{p}.readUint32(p.ptr + ptrlen)
}
func (p vfsFilePtr) Lock() vfsLockState {
return vfsLockState(memory{p}.readUint32(p.ptr + 2*ptrlen))
}
func (p vfsFilePtr) SetID(id uint32) vfsFilePtr {
memory{p}.writeUint32(p.ptr+ptrlen, id)
return p
}
func (p vfsFilePtr) SetLock(lock vfsLockState) vfsFilePtr {
memory{p}.writeUint32(p.ptr+2*ptrlen, uint32(lock))
return p
}

View File

@@ -1,215 +0,0 @@
package sqlite3
import (
"context"
"os"
"github.com/tetratelabs/wazero/api"
)
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
_NO_LOCK = 0
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
_SHARED_LOCK = 1
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
_RESERVED_LOCK = 2
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
_PENDING_LOCK = 3
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
_EXCLUSIVE_LOCK = 4
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
type vfsLockState uint32
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check. SQLite never explicitly requests a pendig lock.
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
file := ptr.OSFile()
cLock := ptr.Lock()
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
// Connection state check.
panic(assertErr())
case cLock == _NO_LOCK && eLock > _SHARED_LOCK:
// We never move from unlocked to anything higher than a shared lock.
panic(assertErr())
case cLock != _SHARED_LOCK && eLock == _RESERVED_LOCK:
// A shared lock is always held when a reserved lock is requested.
panic(assertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
switch eLock {
case _SHARED_LOCK:
// Must be unlocked to get SHARED.
if cLock != _NO_LOCK {
panic(assertErr())
}
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := vfsOS.CheckPendingLock(file); locked {
return uint32(BUSY)
}
if rc := vfsOS.GetSharedLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
return _OK
case _RESERVED_LOCK:
// Must be SHARED to get RESERVED.
if cLock != _SHARED_LOCK {
panic(assertErr())
}
if rc := vfsOS.GetReservedLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_RESERVED_LOCK)
return _OK
case _EXCLUSIVE_LOCK:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if cLock <= _NO_LOCK || cLock >= _EXCLUSIVE_LOCK {
panic(assertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if cLock == _RESERVED_LOCK {
if rc := vfsOS.GetPendingLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_PENDING_LOCK)
}
if rc := vfsOS.GetExclusiveLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_EXCLUSIVE_LOCK)
return _OK
default:
panic(assertErr())
}
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check.
if eLock != _NO_LOCK && eLock != _SHARED_LOCK {
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
file := ptr.OSFile()
cLock := ptr.Lock()
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
panic(assertErr())
}
// If we don't have a more restrictive lock, do nothing.
if cLock <= eLock {
return _OK
}
switch eLock {
case _SHARED_LOCK:
if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
return _OK
case _NO_LOCK:
rc := vfsOS.ReleaseLock(file, cLock)
ptr.SetLock(_NO_LOCK)
return uint32(rc)
default:
panic(assertErr())
}
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
if cLock > _SHARED_LOCK {
panic(assertErr())
}
file := ptr.OSFile()
locked, rc := vfsOS.CheckReservedLock(file)
var res uint32
if locked {
res = 1
}
memory{mod}.writeUint32(pResOut, res)
return uint32(rc)
}
func (vfsOSMethods) GetSharedLock(file *os.File) xErrorCode {
// Acquire the SHARED lock.
return vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
func (vfsOSMethods) GetReservedLock(file *os.File) xErrorCode {
// Acquire the RESERVED lock.
return vfsOS.writeLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) GetPendingLock(file *os.File) xErrorCode {
// Acquire the PENDING lock.
return vfsOS.writeLock(file, _PENDING_BYTE, 1)
}
func (vfsOSMethods) CheckReservedLock(file *os.File) (bool, xErrorCode) {
// Test the RESERVED lock.
return vfsOS.checkLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) CheckPendingLock(file *os.File) (bool, xErrorCode) {
// Test the PENDING lock.
return vfsOS.checkLock(file, _PENDING_BYTE, 1)
}

View File

@@ -1,125 +0,0 @@
package sqlite3
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "illumos", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mem := newMemory(128)
vfsFilePtr{mem.mod, pFile1}.SetID(vfsGetFileID(file1)).SetLock(_NO_LOCK)
vfsFilePtr{mem.mod, pFile2}.SetID(vfsGetFileID(file2)).SetLock(_NO_LOCK)
rc := vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _RESERVED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _EXCLUSIVE_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsUnlock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -1,271 +0,0 @@
package sqlite3
import (
"bytes"
"context"
"errors"
"io/fs"
"os"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/ncruces/julianday"
)
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1)
t.Error("want panic")
}
func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128)
rc := vfsLocaltime(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
epoch := time.Unix(0, 0)
if s := mem.readUint32(4 + 0*4); int(s) != epoch.Second() {
t.Error("wrong second")
}
if m := mem.readUint32(4 + 1*4); int(m) != epoch.Minute() {
t.Error("wrong minute")
}
if h := mem.readUint32(4 + 2*4); int(h) != epoch.Hour() {
t.Error("wrong hour")
}
if d := mem.readUint32(4 + 3*4); int(d) != epoch.Day() {
t.Error("wrong day")
}
if m := mem.readUint32(4 + 4*4); time.Month(1+m) != epoch.Month() {
t.Error("wrong month")
}
if y := mem.readUint32(4 + 5*4); 1900+int(y) != epoch.Year() {
t.Error("wrong year")
}
if w := mem.readUint32(4 + 6*4); time.Weekday(w) != epoch.Weekday() {
t.Error("wrong weekday")
}
if d := mem.readUint32(4 + 7*4); int(d) != epoch.YearDay()-1 {
t.Error("wrong yearday")
}
}
func Test_vfsRandomness(t *testing.T) {
mem := newMemory(128)
rc := vfsRandomness(context.TODO(), mem.mod, 0, 16, 4)
if rc != 16 {
t.Fatal("returned", rc)
}
var zero [16]byte
if got := mem.view(4, 16); bytes.Equal(got, zero[:]) {
t.Fatal("all zero")
}
}
func Test_vfsSleep(t *testing.T) {
start := time.Now()
rc := vfsSleep(context.TODO(), 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
want := 123456 * time.Microsecond
if got := time.Since(start); got < want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime(t *testing.T) {
mem := newMemory(128)
now := time.Now()
rc := vfsCurrentTime(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
want := julianday.Float(now)
if got := mem.readFloat64(4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime64(t *testing.T) {
mem := newMemory(128)
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000
if got := mem.readUint64(4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsFullPathname(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, ".")
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
want, _ := filepath.Abs(".")
if got := mem.readString(8, _MAX_PATHNAME); got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsDelete(t *testing.T) {
name := filepath.Join(t.TempDir(), "test.db")
file, err := os.Create(name)
if err != nil {
t.Fatal(err)
}
file.Close()
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, name)
rc := vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
if _, err := os.Stat(name); !errors.Is(err, fs.ErrNotExist) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}
func Test_vfsAccess(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
f.Close()
}
if err := os.Chmod(file, syscall.S_IRUSR); err != nil {
t.Fatal(err)
}
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
t.Error("can't access directory")
}
mem.writeString(8, file)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {
mem := newMemory(128)
// Open a temporary file.
rc := vfsOpen(context.TODO(), mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Write stuff.
text := "Hello world!"
mem.writeString(16, text)
rc = vfsWrite(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != uint32(len(text)) {
t.Errorf("got %d", got)
}
// Partial read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 4)
if rc != uint32(IOERR_SHORT_READ) {
t.Fatal("returned", rc)
}
if got := mem.readString(16, 64); got != text[4:] {
t.Errorf("got %q", got)
}
// Truncate the file.
rc = vfsTruncate(context.TODO(), mem.mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != 4 {
t.Errorf("got %d", got)
}
// Read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readString(32, 64); got != text[:4] {
t.Errorf("got %q", got)
}
// Close the file.
rc = vfsClose(context.TODO(), mem.mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -1,137 +0,0 @@
//go:build unix
package sqlite3
import (
"os"
"runtime"
"syscall"
)
func (vfsOSMethods) DeleteOnClose(file *os.File) {
_ = os.Remove(file.Name())
}
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Acquire the EXCLUSIVE lock.
return vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return vfsOS.unlock(file, _PENDING_BYTE, 2)
}
func (vfsOSMethods) ReleaseLock(file *os.File, _ vfsLockState) xErrorCode {
// Release all locks.
return vfsOS.unlock(file, 0, 0)
}
func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode {
err := vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (vfsOSMethods) readLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}), IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_WRLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) {
lock := syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}
if vfsOS.fcntlGetLock(file, &lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != syscall.F_UNLCK, _OK
}
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *syscall.Flock_t) error {
var F_OFD_GETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_OFD_GETLK = 36 // F_OFD_GETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_OFD_GETLK = 92 // F_OFD_GETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_OFD_GETLK = 47 // F_OFD_GETLK
default:
return notImplErr
}
return syscall.FcntlFlock(file.Fd(), F_OFD_GETLK, lock)
}
func (vfsOSMethods) fcntlSetLock(file *os.File, lock *syscall.Flock_t) error {
var F_OFD_SETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_OFD_SETLK = 37 // F_OFD_SETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_OFD_SETLK = 90 // F_OFD_SETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_OFD_SETLK = 48 // F_OFD_SETLK
default:
return notImplErr
}
return syscall.FcntlFlock(file.Fd(), F_OFD_SETLK, lock)
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
return xErrorCode(PERM)
}
}
return def
}

View File

@@ -1,102 +0,0 @@
package sqlite3
import (
"os"
"syscall"
"golang.org/x/sys/windows"
)
func (vfsOSMethods) DeleteOnClose(file *os.File) {}
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc != _OK {
vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
return rc
}
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) ReleaseLock(file *os.File, state vfsLockState) xErrorCode {
// Release all locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _SHARED_LOCK {
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) unlock(file *os.File, start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (vfsOSMethods) readLock(file *os.File, start, len uint32) xErrorCode {
return vfsOS.lockErrorCode(windows.LockFileEx(windows.Handle(file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len uint32) xErrorCode {
return vfsOS.lockErrorCode(windows.LockFileEx(windows.Handle(file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len uint32) (bool, xErrorCode) {
rc := vfsOS.readLock(file, start, len)
if rc == _OK {
vfsOS.unlock(file, start, len)
}
return rc != _OK, _OK
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, _ := err.(syscall.Errno); errno == windows.ERROR_INVALID_HANDLE {
return def
}
return xErrorCode(BUSY)
}