Compare commits

..

120 Commits

Author SHA1 Message Date
Nuno Cruces
76171da12b go-sqlite3 v0.7.3. 2023-06-15 03:56:02 +01:00
Nuno Cruces
dcc845d684 wazero v1.2.1. 2023-06-15 03:43:25 +01:00
dependabot[bot]
f1b42c26d5 Bump golang.org/x/sync from 0.2.0 to 0.3.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.2.0 to 0.3.0.
- [Commits](https://github.com/golang/sync/compare/v0.2.0...v0.3.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-15 00:13:43 +01:00
dependabot[bot]
1e94407ae7 Bump golang.org/x/sys from 0.8.0 to 0.9.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.8.0 to 0.9.0.
- [Commits](https://github.com/golang/sys/compare/v0.8.0...v0.9.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-13 00:31:02 +01:00
Nuno Cruces
eb8d9b95fd Consistent lock timeouts. 2023-06-12 13:04:37 +01:00
Nuno Cruces
04037a75ed GORM driver sync. 2023-06-12 10:56:03 +01:00
Nuno Cruces
2472ceb0a0 Fix GORM module name. 2023-06-07 12:40:18 +01:00
Nuno Cruces
bfe9bfde2e Make GORM driver its own module. 2023-06-07 12:00:46 +01:00
Nuno Cruces
f07e82e361 GORM driver. 2023-06-06 12:37:54 +01:00
Nuno Cruces
fbbbe5a631 Fix plain files. 2023-06-06 03:47:02 +01:00
Nuno Cruces
5ea603ed78 Readers should not close. 2023-06-02 15:00:12 +01:00
Nuno Cruces
401cb77e38 binaryen-version_113. 2023-06-02 14:23:12 +01:00
Nuno Cruces
6511175011 wazero v1.2.0. 2023-06-02 14:11:20 +01:00
Nuno Cruces
f7d987fdf1 Commit phase-two API. 2023-06-02 13:40:08 +01:00
Nuno Cruces
00ba681bb5 Batch atomic writes API. 2023-06-02 11:14:34 +01:00
Nuno Cruces
d4d4533a41 Docs. 2023-06-02 03:38:26 +01:00
Nuno Cruces
ec9533b13f Implement modeof. 2023-06-02 03:38:26 +01:00
Nuno Cruces
8fe77a065c Remove wzprof. 2023-06-02 03:38:26 +01:00
Nuno Cruces
7bf5312bd4 Rename. 2023-06-02 03:38:26 +01:00
Nuno Cruces
ae7b74d858 Upgrade wzprof. 2023-06-01 16:09:18 +01:00
Nuno Cruces
9a8de3ad13 Enable memdb on speedtest1. 2023-06-01 15:41:20 +01:00
Nuno Cruces
05737e6025 Refactor reader VFS API. 2023-05-31 19:24:41 +01:00
Nuno Cruces
ac2836bb82 Refactor memdb API. 2023-05-31 16:27:31 +01:00
Nuno Cruces
d0d4b0e1a2 MemoryVFS mutexes. 2023-05-31 12:57:18 +01:00
Nuno Cruces
dc3dc6853d MemoryVFS journal. 2023-05-31 11:56:48 +01:00
dependabot[bot]
830240c368 Bump github.com/stealthrocket/wzprof from 0.1.3 to 0.1.4
Bumps [github.com/stealthrocket/wzprof](https://github.com/stealthrocket/wzprof) from 0.1.3 to 0.1.4.
- [Release notes](https://github.com/stealthrocket/wzprof/releases)
- [Changelog](https://github.com/stealthrocket/wzprof/blob/main/.goreleaser.yml)
- [Commits](https://github.com/stealthrocket/wzprof/compare/v0.1.3...v0.1.4)

---
updated-dependencies:
- dependency-name: github.com/stealthrocket/wzprof
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-30 13:40:52 +01:00
Nuno Cruces
dedec8682b Driver improvements. 2023-05-30 13:39:34 +01:00
Nuno Cruces
a33b828e13 Examples, tests, max size. 2023-05-30 11:21:14 +01:00
Nuno Cruces
8b2e96dedc Tests, fixes, docs. 2023-05-29 16:52:43 +01:00
Nuno Cruces
f1c46db512 VFS locking. 2023-05-27 23:36:39 +01:00
Nuno Cruces
7ca9d79424 MemoryVFS. 2023-05-27 23:36:39 +01:00
Nuno Cruces
254d473546 VFS URI parameters. 2023-05-27 23:36:39 +01:00
Nuno Cruces
5639fc1ff8 Update wzprof. 2023-05-27 23:36:39 +01:00
Nuno Cruces
ae4954d09b Profile with wzprof. 2023-05-27 23:36:39 +01:00
Nuno Cruces
45937d9749 Use wazerotest. 2023-05-27 23:36:39 +01:00
Nuno Cruces
eee71e06aa Tweak calling convention. 2023-05-25 17:03:40 +01:00
Nuno Cruces
9e7b6bb8ea Improve connection setup. 2023-05-25 11:14:18 +01:00
Nuno Cruces
597178f80d Backup fix, tests. 2023-05-24 02:47:18 +01:00
Nuno Cruces
cc2d16ac83 ReaderVFS. 2023-05-23 16:34:09 +01:00
Nuno Cruces
cfb69e4ce7 Reorg. 2023-05-23 14:47:39 +01:00
Nuno Cruces
e6969432e3 Rename. 2023-05-23 14:47:38 +01:00
Nuno Cruces
2b3da350cc Improved error handling. 2023-05-23 14:47:38 +01:00
Nuno Cruces
336ba87d56 Documentation. 2023-05-19 19:47:43 +01:00
Nuno Cruces
dd4823ebf0 Documentation, tests. 2023-05-19 14:45:40 +01:00
Nuno Cruces
663b23ff3b Documentation. 2023-05-19 13:47:37 +01:00
Nuno Cruces
4e2ce6c635 Refactor VFS. 2023-05-19 03:04:07 +01:00
Nuno Cruces
66effb4249 Rename. 2023-05-19 02:28:30 +01:00
Nuno Cruces
e1cce83f71 More VFS API. 2023-05-19 02:00:16 +01:00
Nuno Cruces
df953b31c2 Refactor VFS. 2023-05-18 16:00:34 +01:00
Nuno Cruces
67cc3d35d5 More VFS API. 2023-05-18 01:34:54 +01:00
Nuno Cruces
6846b72b31 Add SetInterrupt to DriverConn. 2023-05-17 14:38:47 +01:00
Nuno Cruces
c94cdaf720 More VFS API. 2023-05-17 14:38:47 +01:00
Nuno Cruces
f6a887dd1c Allow manual runs. 2023-05-17 14:38:35 +01:00
Nuno Cruces
2a010a2022 Towards VFS API. 2023-05-17 01:00:08 +01:00
Nuno Cruces
c86b06b048 Refactor. 2023-05-16 17:52:37 +01:00
Nuno Cruces
a44a13a506 Rename. 2023-05-16 15:40:08 +01:00
Nuno Cruces
4604719966 SQLite 3.42.0. 2023-05-16 14:56:47 +01:00
Nuno Cruces
03168d5d34 Build scripts. 2023-05-16 12:14:34 +01:00
Nuno Cruces
be4b6304f9 Documentation. 2023-05-16 12:14:23 +01:00
Nuno Cruces
b5e678a40a Inline host calls. 2023-05-09 14:41:24 +01:00
Nuno Cruces
2fc4698ddc Avoid some allocs. 2023-05-08 10:47:19 +01:00
dependabot[bot]
bd86539577 Bump golang.org/x/sync from 0.1.0 to 0.2.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.1.0 to 0.2.0.
- [Commits](https://github.com/golang/sync/compare/v0.1.0...v0.2.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-05 07:52:12 +01:00
dependabot[bot]
7a785d9aec Bump golang.org/x/sys from 0.7.0 to 0.8.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.7.0 to 0.8.0.
- [Commits](https://github.com/golang/sys/compare/v0.7.0...v0.8.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-05 07:51:01 +01:00
Nuno Cruces
59f79e8e74 Optimize calls. 2023-05-02 01:08:04 +01:00
dependabot[bot]
40457721d7 Bump github.com/tetratelabs/wazero from 1.0.3 to 1.1.0 (#11)
Bumps [github.com/tetratelabs/wazero](https://github.com/tetratelabs/wazero) from 1.0.3 to 1.1.0.
- [Release notes](https://github.com/tetratelabs/wazero/releases)
- [Commits](https://github.com/tetratelabs/wazero/compare/v1.0.3...v1.1.0)

---
updated-dependencies:
- dependency-name: github.com/tetratelabs/wazero
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 01:03:42 +01:00
Nuno Cruces
18eeb85783 Improve mock. 2023-04-28 13:50:50 +01:00
Nuno Cruces
b36536979b Fix reopen. 2023-04-28 13:50:32 +01:00
Nuno Cruces
a6226c3b31 wazero v1.0.3. 2023-04-22 10:06:48 +01:00
Nuno Cruces
bed2ee7674 Refactor. 2023-04-22 00:15:44 +01:00
Nuno Cruces
7e6d178122 Fix. 2023-04-21 13:33:24 +01:00
Nuno Cruces
f360c77a78 Optimize blobs. (#10) 2023-04-21 13:31:45 +01:00
Nuno Cruces
759b11a05d wazero 1.0.2. 2023-04-18 23:33:56 +01:00
Nuno Cruces
93ce586139 Optimize time. 2023-04-18 01:00:59 +01:00
Nuno Cruces
2e5082c616 Query pragmas at startup. 2023-04-17 00:29:20 +01:00
Nuno Cruces
34acc28af8 Fix CI. 2023-04-14 15:48:20 +01:00
Nuno Cruces
c1a640f7d8 Build using wasi-sdk. 2023-04-14 15:31:17 +01:00
Nuno Cruces
005b15610a Memory optimizations. 2023-04-11 15:33:38 +01:00
Nuno Cruces
23ee4ccb0b Refactor. 2023-04-10 19:55:44 +01:00
Nuno Cruces
3a8cfd036d Dependencies. 2023-04-10 14:24:06 +01:00
Nuno Cruces
c38382fd8e Refactor. 2023-03-31 14:33:24 +01:00
Nuno Cruces
8509e0b6c8 Test coverage. 2023-03-31 13:42:31 +01:00
Nuno Cruces
9c07e57252 Refactor. 2023-03-29 15:06:22 +01:00
Nuno Cruces
80039385d3 Read only files. 2023-03-25 11:46:13 +00:00
Nuno Cruces
89f4327b2b Sync journal directories. 2023-03-25 11:16:51 +00:00
Nuno Cruces
37a3ff37e8 wazero 1.0. 2023-03-24 21:17:30 +00:00
Nuno Cruces
d880d6842c Refactor VFS. 2023-03-23 13:29:26 +00:00
Nuno Cruces
bef46e7954 Locking improvements (windows). 2023-03-23 12:40:55 +00:00
Nuno Cruces
4e72b4d117 Locking fix. 2023-03-23 11:26:19 +00:00
Nuno Cruces
3b08d02a83 Lock refactoring. 2023-03-23 01:55:54 +00:00
Nuno Cruces
b19c12c4c7 SQLite 3.41.2, prefer speed over size. 2023-03-23 00:44:43 +00:00
Nuno Cruces
859a21ef4e CI improvements. 2023-03-22 12:08:33 +00:00
Nuno Cruces
8ff0ee752f Use flock. 2023-03-22 03:15:54 +00:00
Nuno Cruces
589ad86f76 Extensions. 2023-03-21 00:13:12 +00:00
Nuno Cruces
1a3a1be1f6 Fix test. 2023-03-20 14:26:25 +00:00
Nuno Cruces
222c217bc8 Scripts. 2023-03-20 13:06:31 +00:00
Nuno Cruces
c1dc716391 VFS performance. 2023-03-20 11:02:34 +00:00
Nuno Cruces
71e1e5a8ee Avoid some copies. 2023-03-20 02:16:42 +00:00
Nuno Cruces
e4efb20c71 Generate coverage chart. 2023-03-18 03:51:05 +00:00
Nuno Cruces
2c9459d907 Add SQLite speedtest1. 2023-03-18 03:03:11 +00:00
Nuno Cruces
d0875e5fab Lock timeouts. 2023-03-18 01:13:31 +00:00
Nuno Cruces
15dec13f15 FCNTL_SIZE_HINT, refactor. 2023-03-17 17:13:03 +00:00
Nuno Cruces
f38e36109a FCNTL_HAS_MOVED. 2023-03-17 14:11:09 +00:00
Nuno Cruces
4cb65ccbd9 xFileControl, xDeviceCharacteristics, PSOW. 2023-03-17 13:39:19 +00:00
Nuno Cruces
f789c2fb8b OPEN_NOFOLLOW. 2023-03-16 12:27:44 +00:00
Nuno Cruces
c6a2617dfc Locking fixes. 2023-03-16 02:52:22 +00:00
Nuno Cruces
6fc0afcd12 Towards lock timeouts. 2023-03-15 13:58:16 +00:00
Nuno Cruces
77088962f5 SQLite 3.41.1. 2023-03-15 13:29:09 +00:00
Nuno Cruces
71da34861b Fix time collation. 2023-03-13 04:19:58 +00:00
Nuno Cruces
56e8281bdb Time collation tests. 2023-03-10 16:42:20 +00:00
Nuno Cruces
f61d430e65 Documentation. 2023-03-10 16:26:19 +00:00
Nuno Cruces
dbaed53b9a Sync and delete improvements. 2023-03-10 14:17:02 +00:00
Nuno Cruces
8b1bfd04e3 Simplify windows hacks. 2023-03-10 10:43:02 +00:00
Nuno Cruces
11c1687146 Time collation. 2023-03-09 14:42:29 +00:00
Nuno Cruces
94c43a8685 Use access syscall. 2023-03-09 01:59:46 +00:00
Nuno Cruces
a25159a070 Fix sharing violation. 2023-03-09 01:23:52 +00:00
Nuno Cruces
e007e9b060 Windows fixes. 2023-03-08 20:10:46 +00:00
Nuno Cruces
66a730893f Fix readonly transaction rollback. 2023-03-08 18:07:21 +00:00
Nuno Cruces
926adeb3f5 Remove MustPrepare. 2023-03-08 17:39:41 +00:00
Nuno Cruces
677f51bec1 Savepoint API. 2023-03-08 17:39:23 +00:00
Nuno Cruces
5d6f92b733 Documentation, tests, tweaks. 2023-03-08 13:29:33 +00:00
131 changed files with 7161 additions and 2636 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6
custom: https://www.paypal.com/donate?hosted_button_id=33P59ELZWGMK6

View File

@@ -5,6 +5,7 @@ on:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
workflow_dispatch:
jobs:
test:
@@ -18,11 +19,28 @@ jobs:
with:
lfs: 'true'
- name: Set up Go
uses: actions/setup-go@v3
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
cache: true
- name: Format
run: gofmt -s -w . && git diff --exit-code
if: matrix.os != 'windows-latest'
- name: Tidy
run: go mod tidy && git diff --exit-code
- name: Download
run: go mod download
# Fixed in go 1.21: https://go.dev/issue/54372
# - name: Verify
# run: go mod verify
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -30,12 +48,15 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test data races
run: go test -v -race ./...
if: matrix.os == 'ubuntu-latest'
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
if: matrix.os == 'macos-latest'
- name: Update coverage report
uses: ncruces/go-coverage-report@main
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'

View File

@@ -15,9 +15,17 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html)
(aka VFS) with a [pure Go](vfs/) implementation.
This has benefits, but also comes with some drawbacks.
#### Write-Ahead Logging
Because WASM does not support shared memory,
@@ -30,44 +38,47 @@ For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode,
and WAL databases are not supported.
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Open File Description Locks
#### POSIX Advisory Locks
On Unix, this module uses [OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
POSIX advisory locks, which SQLite uses, are [broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
OFD locks are fully compatible with process-associated POSIX advisory locks,
and are supported on Linux, macOS and illumos.
As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.org/uri.html).
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
#### Testing
The pure Go VFS is tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Roadmap
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] design a nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on macOS/Linux/Windows
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [ ] snapshots
- [ ] session extension
- [ ] resumable bulk update
- [ ] shared-cache mode
- [ ] unlock-notify
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] in-memory VFS, wrapping a [`bytes.Buffer`](https://pkg.go.dev/bytes#Buffer)
- [x] custom VFS API
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom VFS API
- [ ] custom SQL functions
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)

View File

@@ -1,6 +1,6 @@
package sqlite3
// Backup is a handle to an open BLOB.
// Backup is an handle to an ongoing online backup operation.
//
// https://www.sqlite.org/c3ref/backup.html
type Backup struct {
@@ -11,7 +11,7 @@ type Backup struct {
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup calls [Conn.Open] to open the SQLite database file dstURI,
// Backup opens the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
@@ -28,12 +28,12 @@ func (src *Conn) Backup(srcDB, dstURI string) error {
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore calls [Conn.Open] to open the SQLite database file srcURI,
// Restore opens the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://www.sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
if err != nil {
return err
}
@@ -48,7 +48,7 @@ func (dst *Conn) Restore(dstDB, srcURI string) error {
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit calls [Conn.Open] to open the SQLite database file dstURI,
// BackupInit opens the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//
@@ -74,16 +74,16 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
r := c.call(c.api.backupInit,
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r[0] == 0 {
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r[0], dst)
return nil, c.module.error(r, dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r[0]),
handle: uint32(r),
}, nil
}
@@ -100,7 +100,7 @@ func (b *Backup) Close() error {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r[0])
return b.c.error(r)
}
// Step copies up to nPage pages between the source and destination databases.
@@ -109,10 +109,10 @@ func (b *Backup) Close() error {
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
if r[0] == _DONE {
if r == _DONE {
return true, nil
}
return false, b.c.error(r[0])
return false, b.c.error(r)
}
// Remaining returns the number of pages still to be backed up
@@ -121,7 +121,7 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
return int(r[0])
return int(r)
}
// PageCount returns the total number of pages in the source database
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
return int(r[0])
r := b.c.call(b.c.api.backupPageCount, uint64(b.handle))
return int(r)
}

157
blob.go
View File

@@ -1,22 +1,26 @@
package sqlite3
import "io"
import (
"io"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64
// Blob is a handle to an open BLOB.
// Blob is an handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
//
// https://www.sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
handle uint32
bytes int64
offset int64
handle uint32
}
var _ io.ReadWriteSeeker = &Blob{}
@@ -25,6 +29,7 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
@@ -40,13 +45,13 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
if err := c.error(r[0]); err != nil {
if err := c.error(r); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = c.mem.readUint32(blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle)))
return &blob, nil
}
@@ -63,7 +68,7 @@ func (b *Blob) Close() error {
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
b.handle = 0
return b.c.error(r[0])
return b.c.error(r)
}
// Size returns the size of the BLOB in bytes.
@@ -81,8 +86,40 @@ func (b *Blob) Read(p []byte) (n int, err error) {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
want := int64(len(p))
if want > avail {
want = avail
}
defer b.c.arena.reset()
ptr := b.c.arena.new(uint64(want))
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
return 0, err
}
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
copy(p, util.View(b.c.mod, ptr, uint64(want)))
return int(want), err
}
// WriteTo implements the [io.WriterTo] interface.
//
// https://www.sqlite.org/c3ref/blob_read.html
func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
if b.offset >= b.bytes {
return 0, nil
}
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
@@ -90,37 +127,43 @@ func (b *Blob) Read(p []byte) (n int, err error) {
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
}
for want > 0 {
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
return n, err
}
mem := b.c.mem.view(ptr, uint64(want))
copy(p, mem)
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := w.Write(mem[:want])
b.offset += int64(m)
n += int64(m)
if err != nil {
return n, err
}
if int64(m) != want {
return n, io.ErrShortWrite
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
}
return int(want), err
return n, nil
}
// Write implements the [io.Writer] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
offset := b.offset
if offset > b.bytes {
offset = b.bytes
}
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
defer b.c.arena.reset()
ptr := b.c.arena.bytes(p)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(offset))
err = b.c.error(r[0])
uint64(ptr), uint64(len(p)), uint64(b.offset))
err = b.c.error(r)
if err != nil {
return 0, err
}
@@ -128,11 +171,57 @@ func (b *Blob) Write(p []byte) (n int, err error) {
return len(p), nil
}
// ReadFrom implements the [io.ReaderFrom] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
for {
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := r.Read(mem[:want])
if m > 0 {
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(m), uint64(b.offset))
err := b.c.error(r)
if err != nil {
return n, err
}
b.offset += int64(m)
n += int64(m)
}
if err == io.EOF {
return n, nil
}
if err != nil {
return n, err
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
}
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, whenceErr
return 0, util.WhenceErr
case io.SeekStart:
break
case io.SeekCurrent:
@@ -141,7 +230,7 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
offset += b.bytes
}
if offset < 0 {
return 0, offsetErr
return 0, util.OffsetErr
}
b.offset = offset
return offset, nil
@@ -151,8 +240,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row)))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle)))
b.offset = 0
return b.c.error(r[0])
return err
}

109
conn.go
View File

@@ -3,12 +3,15 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Conn is a database connection handle.
@@ -18,26 +21,31 @@ import (
type Conn struct {
*module
handle uint32
arena arena
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
handle uint32
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE
}
return newConn(filename, flags)
}
@@ -50,7 +58,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
if conn == nil {
mod.close()
} else {
runtime.SetFinalizer(conn, finalizer[Conn](3))
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
@@ -68,10 +76,11 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := c.mem.readUint32(connPtr)
if err := c.module.error(r[0], handle); err != nil {
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -87,19 +96,24 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
}
}
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if err := c.module.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
c.closeDB(handle)
return 0, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
return 0, err
}
}
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r[0], handle); err != nil {
if err := c.module.error(r, handle); err != nil {
panic(err)
}
}
@@ -119,9 +133,11 @@ func (c *Conn) Close() error {
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
if err := c.error(r); err != nil {
return err
}
@@ -140,24 +156,7 @@ func (c *Conn) Exec(sql string) error {
sqlPtr := c.arena.string(sql)
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r[0])
}
// MustPrepare calls [Conn.Prepare] and panics on error,
// a nil Stmt, or a non-empty tail.
func (c *Conn) MustPrepare(sql string) *Stmt {
s, tail, err := c.PrepareFlags(sql, 0)
if err != nil {
panic(err)
}
if s == nil {
panic(emptyErr)
}
if !emptyStatement(tail) {
s.Close()
panic(tailErr)
}
return s
return c.error(r)
}
// Prepare calls [Conn.PrepareFlags] with no flags.
@@ -186,11 +185,11 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
uint64(stmtPtr), uint64(tailPtr))
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
if err := c.error(r, sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
@@ -204,7 +203,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
// https://www.sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call(c.api.autocommit, uint64(c.handle))
return r[0] != 0
return r != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
@@ -213,7 +212,7 @@ func (c *Conn) GetAutocommit() bool {
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
r := c.call(c.api.lastRowid, uint64(c.handle))
return int64(r[0])
return int64(r)
}
// Changes returns the number of rows modified, inserted or deleted
@@ -223,7 +222,7 @@ func (c *Conn) LastInsertRowID() int64 {
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
r := c.call(c.api.changes, uint64(c.handle))
return int64(r[0])
return int64(r)
}
// SetInterrupt interrupts a long-running query when a context is done.
@@ -247,26 +246,23 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Reset the pending statement.
if c.pending != nil {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
} else {
c.pending.Reset()
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
@@ -282,7 +278,8 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
break
case <-ctx.Done(): // Done was closed.
buf := c.mem.view(c.handle+c.api.interrupt, 4)
const isInterruptedOffset = 280
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
@@ -298,7 +295,8 @@ func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
buf := c.mem.view(c.handle+c.api.interrupt, 4)
const isInterruptedOffset = 280
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
@@ -306,15 +304,18 @@ func (c *Conn) checkInterrupt() bool {
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) []string {
stmt := c.MustPrepare(`PRAGMA ` + str)
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil {
return nil, err
}
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas
return pragmas, stmt.Close()
}
func (c *Conn) error(rc uint64, sql ...string) error {
@@ -324,15 +325,21 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints] and [incremental BLOB I/O].
// [savepoints], [online backup] and [incremental BLOB I/O].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
Savepoint() (release func(*error))
SetInterrupt(ctx context.Context) (old context.Context)
Savepoint() Savepoint
Backup(srcDB, dstURI string) error
Restore(dstDB, srcURI string) error
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -9,8 +9,7 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_ALLOCATION_SIZE = 0x7ffffeff
@@ -133,46 +132,28 @@ const (
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
// OpenFlag is a flag for a file open operation.
// OpenFlag is a flag for the [OpenFlags] function.
//
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].

View File

@@ -16,9 +16,12 @@
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// If no PRAGMAs are specifed, a busy timeout of 1 minute
// If no PRAGMAs are specified, a busy timeout of 1 minute
// and normal locking mode are used.
//
// Order matters:
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
//
// [URI]: https://www.sqlite.org/uri.html
// [PRAGMA]: https://www.sqlite.org/pragma.html
// [TRANSACTION]: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
@@ -35,6 +38,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
@@ -44,12 +48,13 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
var c conn
c.Conn, err = sqlite3.Open(name)
if err != nil {
return nil, err
}
var txBegin string
c.txBegin = "BEGIN"
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
@@ -57,9 +62,9 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
@@ -69,7 +74,7 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
}
if len(pragmas) == 0 {
err := c.Exec(`
err := c.Conn.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
@@ -77,81 +82,116 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
c.Close()
return nil, err
}
c.reusable = true
} else {
s, _, err := c.Conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}
return conn{
conn: c,
txBegin: txBegin,
}, nil
return &c, nil
}
type conn struct {
conn *sqlite3.Conn
txBegin string
txCommit string
*sqlite3.Conn
txBegin string
txCommit string
txRollback string
reusable bool
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ sqlite3.DriverConn = conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c conn) Close() error {
return c.conn.Close()
func (c *conn) IsValid() bool {
return c.reusable
}
func (c conn) Begin() (driver.Tx, error) {
func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + c.conn.Pragma("query_only")[0]
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
}
err := c.conn.Exec(txBegin)
switch opts.Isolation {
default:
return nil, util.IsolationErr
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.Conn.Exec(txBegin)
if err != nil {
return nil, err
}
return c, nil
}
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil {
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
func (c *conn) Rollback() error {
return c.Conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, tail, err := c.conn.Prepare(query)
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
s, tail, err := c.Conn.Prepare(query)
if err != nil {
return nil, err
}
if tail != "" {
// Check if the tail contains any SQL.
st, _, err := c.conn.Prepare(tail)
st, _, err := c.Conn.Prepare(tail)
if err != nil {
s.Close()
return nil, err
@@ -159,64 +199,49 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
if st != nil {
s.Close()
st.Close()
return nil, tailErr
return nil, util.TailErr
}
}
return stmt{s, c.conn}, nil
return &stmt{s, c.Conn}, nil
}
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return c.Prepare(query)
}
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
old := c.conn.SetInterrupt(ctx)
defer c.conn.SetInterrupt(old)
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.conn.Exec(query)
err := c.Conn.Exec(query)
if err != nil {
return nil, err
}
return result{
c.conn.LastInsertRowID(),
c.conn.Changes(),
}, nil
}
func (c conn) Savepoint() (release func(*error)) {
return c.conn.Savepoint()
}
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
return c.conn.OpenBlob(db, table, column, row, write)
return newResult(c.Conn), nil
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
}
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.NamedValueChecker = stmt{}
_ driver.StmtExecContext = &stmt{}
_ driver.StmtQueryContext = &stmt{}
_ driver.NamedValueChecker = &stmt{}
)
func (s stmt) Close() error {
return s.stmt.Close()
func (s *stmt) Close() error {
return s.Stmt.Close()
}
func (s stmt) NumInput() int {
n := s.stmt.BindCount()
func (s *stmt) NumInput() int {
n := s.Stmt.BindCount()
for i := 1; i <= n; i++ {
if s.stmt.BindName(i) != "" {
if s.Stmt.BindName(i) != "" {
return -1
}
}
@@ -224,16 +249,16 @@ func (s stmt) NumInput() int {
}
// Deprecated: use ExecContext instead.
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), namedValues(args))
}
// Deprecated: use QueryContext instead.
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), namedValues(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
// Use QueryContext to setup bindings.
// No need to close rows: that simply resets the statement, exec does the same.
_, err := s.QueryContext(ctx, args)
@@ -241,19 +266,16 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
return nil, err
}
err = s.stmt.Exec()
err = s.Stmt.Exec()
if err != nil {
return nil, err
}
return result{
int64(s.conn.LastInsertRowID()),
int64(s.conn.Changes()),
}, nil
return newResult(s.Conn), nil
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.stmt.ClearBindings()
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.Stmt.ClearBindings()
if err != nil {
return nil, err
}
@@ -265,7 +287,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range []string{":", "@", "$"} {
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
@@ -274,25 +296,25 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
err = s.Stmt.BindBool(id, a)
case int:
err = s.stmt.BindInt(id, a)
err = s.Stmt.BindInt(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
err = s.Stmt.BindInt64(id, a)
case float64:
err = s.stmt.BindFloat(id, a)
err = s.Stmt.BindFloat(id, a)
case string:
err = s.stmt.BindText(id, a)
err = s.Stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case nil:
err = s.stmt.BindNull(id)
err = s.Stmt.BindNull(id)
default:
panic(assertErr)
panic(util.AssertErr())
}
}
if err != nil {
@@ -300,10 +322,10 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
}
}
return rows{ctx, s.stmt, s.conn}, nil
return &rows{ctx, s.Stmt, s.Conn}, nil
}
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
@@ -313,6 +335,17 @@ func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
}
}
func newResult(c *sqlite3.Conn) driver.Result {
rows := c.Changes()
if rows != 0 {
id := c.LastInsertRowID()
if id != 0 {
return result{id, rows}
}
}
return resultRowsAffected(rows)
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {
@@ -323,47 +356,56 @@ func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type resultRowsAffected int64
func (r resultRowsAffected) LastInsertId() (int64, error) {
return 0, nil
}
func (r resultRowsAffected) RowsAffected() (int64, error) {
return int64(r), nil
}
type rows struct {
ctx context.Context
stmt *sqlite3.Stmt
conn *sqlite3.Conn
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
}
func (r rows) Close() error {
return r.stmt.Reset()
func (r *rows) Close() error {
return r.Stmt.Reset()
}
func (r rows) Columns() []string {
count := r.stmt.ColumnCount()
func (r *rows) Columns() []string {
count := r.Stmt.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.stmt.ColumnName(i)
columns[i] = r.Stmt.ColumnName(i)
}
return columns
}
func (r rows) Next(dest []driver.Value) error {
old := r.conn.SetInterrupt(r.ctx)
defer r.conn.SetInterrupt(old)
func (r *rows) Next(dest []driver.Value) error {
old := r.Conn.SetInterrupt(r.ctx)
defer r.Conn.SetInterrupt(old)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {
return err
}
return io.EOF
}
for i := range dest {
switch r.stmt.ColumnType(i) {
switch r.Stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.stmt.ColumnInt64(i)
dest[i] = r.Stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeTime(r.stmt.ColumnText(i))
dest[i] = r.Stmt.ColumnFloat(i)
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)
dest[i] = r.Stmt.ColumnRawBlob(i)
case sqlite3.TEXT:
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
@@ -371,9 +413,9 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = nil
}
default:
panic(assertErr)
panic(util.AssertErr())
}
}
return r.stmt.Err()
return r.Stmt.Err()
}

View File

@@ -1,4 +1,3 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
@@ -12,6 +11,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_Open_dir(t *testing.T) {
@@ -134,14 +134,16 @@ func Test_BeginTx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(isolationErr) {
if err.Error() != string(util.IsolationErr) {
t.Error("want isolationErr")
}
@@ -219,7 +221,7 @@ func Test_Prepare(t *testing.T) {
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
if err.Error() != string(tailErr) {
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
}
@@ -297,7 +299,7 @@ func Test_QueryRow_blob_null(t *testing.T) {
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
for i := 0; rows.Next(); i++ {
var buf []byte
var buf sql.RawBytes
err = rows.Scan(&buf)
if err != nil {
t.Fatal(err)

View File

@@ -1,11 +0,0 @@
package driver
type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
isolationErr = errorString("sqlite3: unsupported isolation level")
)

View File

@@ -9,23 +9,23 @@ import (
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) driver.Value {
func stringOrTime(text []byte) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return text
return string(text)
}
if len(text) > len(time.RFC3339Nano) {
return text
return string(text)
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return text
return string(text)
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && date.Format(time.RFC3339Nano) == string(text) {
return date
}
return text
return string(text)
}

View File

@@ -6,7 +6,7 @@ import (
)
// This checks that any string can be recovered as the same string.
func Fuzz_maybeTime_1(f *testing.F) {
func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
@@ -22,7 +22,7 @@ func Fuzz_maybeTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeTime(str)
value := stringOrTime([]byte(str))
switch v := value.(type) {
case time.Time:
@@ -48,7 +48,7 @@ func Fuzz_maybeTime_1(f *testing.F) {
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_maybeTime_2(f *testing.F) {
func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
@@ -59,7 +59,7 @@ func Fuzz_maybeTime_2(f *testing.F) {
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
value := maybeTime(date.Format(time.RFC3339Nano))
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
switch v := value.(type) {
case time.Time:

View File

@@ -48,7 +48,8 @@ func ExampleDriverConn() {
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
defer conn.Savepoint()(&err)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {

24
embed/README.md Normal file
View File

@@ -0,0 +1,24 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
- [math functions](https://www.sqlite.org/lang_mathfunc.html)
- [FTS3/4](https://www.sqlite.org/fts3.html)/[5](https://www.sqlite.org/fts5.html)
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
- [series](https://github.com/sqlite/sqlite/blob/master/ext/misc/series.c)
- [uint](https://github.com/sqlite/sqlite/blob/master/ext/misc/uint.c)
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
- [time](../sqlite3/time.c)
See the [configuration options](../sqlite3/sqlite_cfg.h),
and [patches](../sqlite3) applied.
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
and [`binaryen`](https://github.com/WebAssembly/binaryen).

View File

@@ -1,16 +1,27 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
# download SQLite
../sqlite3/download.sh
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
# build SQLite
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o sqlite3.wasm ../sqlite3/amalg.c \
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
$(awk '{print "-Wl,--export="$0}' ../sqlite3/exports.txt)
-Wl,--initial-memory=327680 \
-Wl,--stack-first \
-Wl,--import-undefined \
$(awk '{print "-Wl,--export="$0}' exports.txt)
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

View File

@@ -37,13 +37,13 @@ sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_unlock_notify
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_interrupt_offset
sqlite3_uri_parameter
sqlite3_uri_key
sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit

View File

@@ -4,9 +4,6 @@
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
package embed
import (

Binary file not shown.

121
error.go
View File

@@ -1,20 +1,20 @@
package sqlite3
import (
"fmt"
"runtime"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Error wraps an SQLite Error Code.
//
// https://www.sqlite.org/c3ref/errcode.html
type Error struct {
code uint64
str string
msg string
sql string
code uint64
}
// Code returns the primary error code for this error.
@@ -85,72 +85,7 @@ func (e *Error) SQL() string {
// Error implements the error interface.
func (e ErrorCode) Error() string {
switch e {
case _OK:
return "sqlite3: not an error"
case _ROW:
return "sqlite3: another row available"
case _DONE:
return "sqlite3: no more rows available"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
return util.ErrorCodeString(uint32(e))
}
// Temporary returns true for [BUSY] errors.
@@ -160,17 +95,7 @@ func (e ErrorCode) Temporary() bool {
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
switch x := ErrorCode(e); {
case e == ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case x < _ROW:
return x.Error()
case e == _ROW:
return "sqlite3: another row available"
case e == _DONE:
return "sqlite3: no more rows available"
}
return "sqlite3: unknown error"
return util.ErrorCodeString(uint32(e))
}
// Is tests whether this error matches a given [ErrorCode].
@@ -188,39 +113,3 @@ func (e ExtendedErrorCode) Temporary() bool {
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
type errorString string
func (e errorString) Error() string { return string(e) }
const (
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
emptyErr = errorString("sqlite3: empty statement")
tailErr = errorString("sqlite3: non-empty tail")
notImplErr = errorString("sqlite3: not implemented")
whenceErr = errorString("sqlite3: invalid whence")
offsetErr = errorString("sqlite3: invalid offset")
)
func assertErr() errorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return errorString(msg)
}
func finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(errorString(msg)) }
}

View File

@@ -1,15 +1,16 @@
package sqlite3
import (
"context"
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") {
err := util.AssertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") {
t.Errorf("got %q", s)
}
}
@@ -120,10 +121,8 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
@@ -144,10 +143,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {

View File

@@ -26,7 +26,11 @@ func Example() {
log.Fatal(err)
}
stmt := db.MustPrepare(`SELECT id, name FROM users`)
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))

7
go.mod
View File

@@ -4,9 +4,10 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-rc.1
golang.org/x/sync v0.1.0
golang.org/x/sys v0.6.0
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.2.1
golang.org/x/sync v0.3.0
golang.org/x/sys v0.9.0
)
retract v0.4.0 // tagged from the wrong branch

14
go.sum
View File

@@ -1,8 +1,10 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-rc.1 h1:ytecMV5Ue0BwezjKh/cM5yv1Mo49ep2R2snSsQUyToc=
github.com/tetratelabs/wazero v1.0.0-rc.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

6
go.work Normal file
View File

@@ -0,0 +1,6 @@
go 1.19
use (
.
./gormlite
)

22
gormlite/LICENSE Normal file
View File

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2023 Nuno Cruces
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

26
gormlite/README.md Normal file
View File

@@ -0,0 +1,26 @@
# GORM SQLite Driver
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
## Usage
```go
import (
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/gormlite"
"gorm.io/gorm"
)
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
```
Checkout [https://gorm.io](https://gorm.io) for details.
### Foreign-key constraint activation
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
```go
db, err := gorm.Open(gormlite.Open(
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
&gorm.Config{})
```

231
gormlite/ddlmod.go Normal file
View File

@@ -0,0 +1,231 @@
package gormlite
import (
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"gorm.io/gorm/migrator"
)
var (
sqliteSeparator = "`|\"|'|\t"
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)
func getAllColumns(s string) []string {
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
columns := make([]string, 0, len(allMatches))
for _, matches := range allMatches {
if len(matches) > 1 {
columns = append(columns, matches[1])
}
}
return columns
}
type ddl struct {
head string
fields []string
columns []migrator.ColumnType
}
func parseDDL(strs ...string) (*ddl, error) {
var result ddl
for _, str := range strs {
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
var (
ddlBody = sections[2]
ddlBodyRunes = []rune(ddlBody)
bracketLevel int
quote rune
buf string
)
ddlBodyRunesLen := len(ddlBodyRunes)
result.head = sections[1]
for idx := 0; idx < ddlBodyRunesLen; idx++ {
var (
next rune = 0
c = ddlBodyRunes[idx]
)
if idx+1 < ddlBodyRunesLen {
next = ddlBodyRunes[idx+1]
}
if sc := string(c); separatorRegexp.MatchString(sc) {
if c == next {
buf += sc // Skip escaped quote
idx++
} else if quote > 0 {
quote = 0
} else {
quote = c
}
} else if quote == 0 {
if c == '(' {
bracketLevel++
} else if c == ')' {
bracketLevel--
} else if bracketLevel == 0 {
if c == ',' {
result.fields = append(result.fields, strings.TrimSpace(buf))
buf = ""
continue
}
}
}
if bracketLevel < 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
buf += string(c)
}
if bracketLevel != 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
if buf != "" {
result.fields = append(result.fields, strings.TrimSpace(buf))
}
for _, f := range result.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") {
continue
}
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
for _, name := range getAllColumns(f) {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
columnType := migrator.ColumnType{
NameValue: sql.NullString{String: matches[1], Valid: true},
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
matchUpper := strings.ToUpper(matches[3])
if strings.Contains(matchUpper, " NOT NULL") {
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
} else if strings.Contains(matchUpper, " NULL") {
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " UNIQUE") {
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " PRIMARY") {
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
}
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
if strings.ToLower(defaultMatches[1]) != "null" {
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
}
}
// data type length
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
if len(matches) == 1 && len(matches[0]) == 2 {
size, _ := strconv.Atoi(matches[0][1])
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
}
result.columns = append(result.columns, columnType)
}
}
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
for _, column := range getAllColumns(matches[1]) {
for idx, c := range result.columns {
if c.NameValue.String == column {
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
result.columns[idx] = c
}
}
}
} else {
return nil, errors.New("invalid DDL")
}
}
return &result, nil
}
func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
}
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}
func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return
}
}
d.fields = append(d.fields, sql)
}
func (d *ddl) removeConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}
return false
}
func (d *ddl) getColumns() []string {
res := []string{}
for _, f := range d.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
continue
}
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
match := reg.FindStringSubmatch(f)
if match != nil {
res = append(res, "`"+match[1]+"`")
}
}
return res
}

352
gormlite/ddlmod_test.go Normal file
View File

@@ -0,0 +1,352 @@
package gormlite
import (
"database/sql"
"testing"
"gorm.io/gorm/migrator"
"gorm.io/gorm/utils/tests"
)
func TestParseDDL(t *testing.T) {
params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{"with_fk", []string{
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"no brackets", []string{"create table test"}, 0, nil},
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
"table_name_with_dash",
[]string{
"CREATE TABLE `test-a` (`id` int NOT NULL)",
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "int", Valid: true},
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
},
},
},
{
"unique index",
[]string{
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
{
"non-unique index",
[]string{
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)
if err != nil {
panic(err.Error())
}
tests.AssertEqual(t, p.sql[0], ddl.compile())
if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}
func TestParseDDL_Whitespaces(t *testing.T) {
testColumns := []migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
},
{
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
},
}
params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{
"with_newline",
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_newline_2",
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_missing_space",
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_many_spaces",
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)
if err != nil {
panic(err.Error())
}
if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}
func TestParseDDL_error(t *testing.T) {
params := []struct {
name string
sql string
}{
{"invalid_cmd", "CREATE TABLE"},
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
_, err := parseDDL(p.sql)
if err == nil {
t.Fail()
}
})
}
}
func TestAddConstraint(t *testing.T) {
params := []struct {
name string
fields []string
cName string
sql string
expect []string
}{
{
name: "add_new",
fields: []string{"`id` integer NOT NULL"},
cName: "fk_users_notes",
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
},
{
name: "update",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
cName: "fk_users_notes",
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
},
{
name: "add_check",
fields: []string{"`id` integer NOT NULL"},
cName: "name_checker",
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
},
{
name: "update_check",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
cName: "name_checker",
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL := ddl{fields: p.fields}
testDDL.addConstraint(p.cName, p.sql)
tests.AssertEqual(t, p.expect, testDDL.fields)
})
}
}
func TestRemoveConstraint(t *testing.T) {
params := []struct {
name string
fields []string
cName string
success bool
expect []string
}{
{
name: "fk",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
cName: "fk_users_notes",
success: true,
expect: []string{"`id` integer NOT NULL"},
},
{
name: "check",
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
cName: "name_checker",
success: true,
expect: []string{"`id` integer NOT NULL"},
},
{
name: "none",
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
cName: "nothing",
success: false,
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL := ddl{fields: p.fields}
success := testDDL.removeConstraint(p.cName)
tests.AssertEqual(t, p.success, success)
tests.AssertEqual(t, p.expect, testDDL.fields)
})
}
}
func TestGetColumns(t *testing.T) {
params := []struct {
name string
ddl string
columns []string
}{
{
name: "with_fk",
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
columns: []string{"`id`", "`text`", "`user_id`"},
},
{
name: "with_check",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
},
{
name: "with_escaped_quote",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
},
{
name: "with_generated_column",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
},
{
name: "with_new_line",
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
"id" integer PRIMARY KEY AUTOINCREMENT,
"created_at" datetime NOT NULL,
"updated_at" datetime NOT NULL,
"created_by" integer NOT NULL DEFAULT 0,
"updated_by" integer NOT NULL DEFAULT 0,
"role_id" integer NOT NULL,
"menu_id" bigint NOT NULL
)`,
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL, err := parseDDL(p.ddl)
if err != nil {
panic(err.Error())
}
cols := testDDL.getColumns()
tests.AssertEqual(t, p.columns, cols)
})
}
}

11
gormlite/download.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"

View File

@@ -0,0 +1,21 @@
package gormlite
import (
"errors"
"github.com/ncruces/go-sqlite3"
"gorm.io/gorm"
)
func (dialector Dialector) Translate(err error) error {
switch {
case
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
return gorm.ErrDuplicatedKey
case
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
return err // gorm.ErrForeignKeyViolated (gorm v1.25.2)
}
return err
}

16
gormlite/go.mod Normal file
View File

@@ -0,0 +1,16 @@
module github.com/ncruces/go-sqlite3/gormlite
go 1.19
require (
github.com/ncruces/go-sqlite3 v0.7.3
gorm.io/gorm v1.25.1
)
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.2.1 // indirect
golang.org/x/sys v0.9.0 // indirect
)

14
gormlite/go.sum Normal file
View File

@@ -0,0 +1,14 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.7.3 h1:yX9GebkDvaf1Z+VnxY77Od7GbWTWFCq9yvNzYJYMsaY=
github.com/ncruces/go-sqlite3 v0.7.3/go.mod h1:EhHe1qvG6Zc/8ffYMzre8n//rTRs1YNN5dUD1f1mEGc=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

431
gormlite/migrator.go Normal file
View File

@@ -0,0 +1,431 @@
package gormlite
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Migrator struct {
migrator.Migrator
}
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
var enabled int
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
if enabled == 1 {
m.DB.Exec("PRAGMA foreign_keys = OFF")
defer m.DB.Exec("PRAGMA foreign_keys = ON")
}
return fc()
}
func (m Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
return m.RunWithoutForeignKey(func() error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
return nil
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
}
if name != "" {
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
).Row().Scan(&count)
}
return nil
})
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
})
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
sqls []string
sqlDDL *ddl
)
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
return err
}
if sqlDDL, err = parseDDL(sqls...); err != nil {
return err
}
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil {
return err
}
defer func() {
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnType := migrator.ColumnType{SQLColumnType: c}
for _, column := range sqlDDL.columns {
if column.NameValue.String == c.Name() {
column.SQLColumnType = c
columnType = column
break
}
}
columnTypes = append(columnTypes, columnType)
}
return err
})
return columnTypes, execErr
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, "")
return createSQL, nil, nil
})
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
var (
constraintName string
constraintSql string
constraintValues []interface{}
)
if constraint != nil {
constraintName = constraint.Name
constraintSql, constraintValues = buildConstraint(constraint)
} else if chk != nil {
constraintName = chk.Name
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
}
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()
return createSQL, constraintValues, nil
})
})
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()
return createSQL, nil, nil
})
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
).Row().Scan(&count)
return nil
})
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
if name != "" {
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
).Row().Scan(&count)
}
return nil
})
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
if err := m.DropIndex(value, oldName); err != nil {
return err
}
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
})
}
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (m Migrator) getRawDDL(table string) (string, error) {
var createSQL string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
if m.DB.Error != nil {
return "", m.DB.Error
}
return createSQL, nil
}
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
table = *tablePtr
}
rawDDL, err := m.getRawDDL(table)
if err != nil {
return err
}
newTableName := table + "__temp"
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
if err != nil {
return err
}
if createSQL == "" {
return nil
}
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createDDL, err := parseDDL(createSQL)
if err != nil {
return err
}
columns := createDDL.getColumns()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
return err
}
queries := []string{
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
fmt.Sprintf("DROP TABLE `%v`", table),
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
}
for _, query := range queries {
if err := tx.Exec(query).Error; err != nil {
return err
}
}
return nil
})
})
}

219
gormlite/sqlite.go Normal file
View File

@@ -0,0 +1,219 @@
// Package gormlite provides a GORM driver for SQLite.
package gormlite
import (
"context"
"database/sql"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
DSN string
Conn gorm.ConnPool
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
if err != nil {
return err
}
db.ConnPool = conn
}
var version string
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
return err
}
// https://www.sqlite.org/releaselog/3_35_0.html
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
} else {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
LastInsertIDReversed: true,
})
}
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
stmt.WriteString("INSERT ")
if insert.Modifier != "" {
stmt.WriteString(insert.Modifier)
stmt.WriteByte(' ')
}
stmt.WriteString("INTO ")
if insert.Table.Name == "" {
stmt.WriteQuoted(stmt.Table)
} else {
stmt.WriteQuoted(insert.Table)
}
return
}
}
c.Build(builder)
},
"LIMIT": func(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
var lmt = -1
if limit.Limit != nil && *limit.Limit >= 0 {
lmt = *limit.Limit
}
if lmt >= 0 || limit.Offset > 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(lmt))
}
if limit.Offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
},
"FOR": func(c clause.Clause, builder clause.Builder) {
if _, ok := c.Expression.(clause.Locking); ok {
// SQLite3 does not support row-level locking.
return
}
c.Build(builder)
},
}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
if field.AutoIncrement {
return clause.Expr{SQL: "NULL"}
}
// doesn't work, will raise error
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
if strings.Contains(str, ".") {
for idx, str := range strings.Split(str, ".") {
if idx > 0 {
writer.WriteString(".`")
}
writer.WriteString(str)
writer.WriteByte('`')
}
} else {
writer.WriteString(str)
writer.WriteByte('`')
}
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
return "integer"
}
case schema.Float:
return "real"
case schema.String:
return "text"
case schema.Time:
// Distinguish between schema.Time and tag time
if val, ok := field.TagSettings["TYPE"]; ok {
return val
} else {
return "datetime"
}
case schema.Bytes:
return "blob"
}
return string(field.DataType)
}
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}
func compareVersion(version1, version2 string) int {
n, m := len(version1), len(version2)
i, j := 0, 0
for i < n || j < m {
x := 0
for ; i < n && version1[i] != '.'; i++ {
x = x*10 + int(version1[i]-'0')
}
i++
y := 0
for ; j < m && version2[j] != '.'; j++ {
y = y*10 + int(version2[j]-'0')
}
j++
if x > y {
return 1
}
if x < y {
return -1
}
}
return 0
}

64
gormlite/sqlite_test.go Normal file
View File

@@ -0,0 +1,64 @@
package gormlite
import (
"fmt"
"testing"
"gorm.io/gorm"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
rows := []struct {
description string
dialector *Dialector
openSuccess bool
query string
querySuccess bool
}{
{
description: "Default driver",
dialector: &Dialector{
DSN: InMemoryDSN,
},
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
db, err := gorm.Open(row.dialector, &gorm.Config{})
if !row.openSuccess {
if err == nil {
t.Errorf("Expected Open to fail.")
}
return
}
if err != nil {
t.Errorf("Expected Open to succeed; got error: %v", err)
}
if db == nil {
t.Errorf("Expected db to be non-nil.")
}
if row.query != "" {
err = db.Exec(row.query).Error
if !row.querySuccess {
if err == nil {
t.Errorf("Expected query to fail.")
}
return
}
if err != nil {
t.Errorf("Expected query to succeed; got error: %v", err)
}
}
})
}
}

18
gormlite/test.sh Executable file
View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
rm -rf gorm/ tests/
git clone --filter=blob:none --branch=v1.25.1 https://github.com/go-gorm/gorm.git
mv gorm/tests tests
rm -rf gorm/
patch -p1 -N < tests.patch
cd tests
go mod tidy && go work use . && go test
cd ..
rm -rf tests/
go work use -r .

63
gormlite/tests.patch Normal file
View File

@@ -0,0 +1,63 @@
diff --git a/tests/.gitignore b/tests/.gitignore
index 08cb523..72e8ffc 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1 @@
-go.sum
+*
diff --git a/tests/go.mod b/tests/go.mod
index f47d175..84b80c2 100644
--- a/tests/go.mod
+++ b/tests/go.mod
@@ -7,13 +7,13 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.8
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
+ github.com/ncruces/go-sqlite3 v0.7.2
+ github.com/ncruces/go-sqlite3/gormlite v0.0.0
golang.org/x/crypto v0.8.0 // indirect
gorm.io/driver/mysql v1.5.0
gorm.io/driver/postgres v1.5.0
- gorm.io/driver/sqlite v1.5.0
gorm.io/driver/sqlserver v1.4.3
- gorm.io/gorm v1.25.0
+ gorm.io/gorm v1.25.1
)
-replace gorm.io/gorm => ../
+replace github.com/ncruces/go-sqlite3/gormlite => ../
diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go
index 1412169..472434b 100644
--- a/tests/scanner_valuer_test.go
+++ b/tests/scanner_valuer_test.go
@@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
return errors.New("Too short")
}
- *data = b[3:]
+ *data = append((*data)[0:], b[3:]...)
return nil
} else if s, ok := value.(string); ok {
- *data = []byte(s)[3:]
+ *data = []byte(s[3:])
return nil
}
diff --git a/tests/tests_test.go b/tests/tests_test.go
index 90eb847..cd9af43 100644
--- a/tests/tests_test.go
+++ b/tests/tests_test.go
@@ -7,9 +7,11 @@ import (
"path/filepath"
"time"
+ _ "github.com/ncruces/go-sqlite3/embed"
+ sqlite "github.com/ncruces/go-sqlite3/gormlite"
+
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
- "gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"

22
internal/util/bool.go Normal file
View File

@@ -0,0 +1,22 @@
package util
import "strings"
func ParseBool(s string) (b, ok bool) {
if len(s) == 0 {
return false, false
}
if s[0] == '0' {
return false, true
}
if '1' <= s[0] && s[0] <= '9' {
return true, true
}
switch strings.ToLower(s) {
case "true", "yes", "on":
return true, true
case "false", "no", "off":
return false, true
}
return false, false
}

View File

@@ -0,0 +1,28 @@
package util
import "testing"
func TestParseBool(t *testing.T) {
tests := []struct {
str string
val bool
ok bool
}{
{"", false, false},
{"0", false, true},
{"1", true, true},
{"9", true, true},
{"T", false, false},
{"true", true, true},
{"FALSE", false, true},
{"false?", false, false},
}
for _, tt := range tests {
t.Run(tt.str, func(t *testing.T) {
gotVal, gotOK := ParseBool(tt.str)
if gotVal != tt.val || gotOK != tt.ok {
t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
}
})
}
}

116
internal/util/const.go Normal file
View File

@@ -0,0 +1,116 @@
package util
// https://sqlite.com/matrix/rescode.html
const (
OK = 0 /* Successful result */
ERROR = 1 /* Generic error */
INTERNAL = 2 /* Internal logic error in SQLite */
PERM = 3 /* Access permission denied */
ABORT = 4 /* Callback routine requested an abort */
BUSY = 5 /* The database file is locked */
LOCKED = 6 /* A table in the database is locked */
NOMEM = 7 /* A malloc() failed */
READONLY = 8 /* Attempt to write a readonly database */
INTERRUPT = 9 /* Operation terminated by sqlite3_interrupt() */
IOERR = 10 /* Some kind of disk I/O error occurred */
CORRUPT = 11 /* The database disk image is malformed */
NOTFOUND = 12 /* Unknown opcode in sqlite3_file_control() */
FULL = 13 /* Insertion failed because database is full */
CANTOPEN = 14 /* Unable to open the database file */
PROTOCOL = 15 /* Database lock protocol error */
EMPTY = 16 /* Internal use only */
SCHEMA = 17 /* The database schema changed */
TOOBIG = 18 /* String or BLOB exceeds size limit */
CONSTRAINT = 19 /* Abort due to constraint violation */
MISMATCH = 20 /* Data type mismatch */
MISUSE = 21 /* Library used incorrectly */
NOLFS = 22 /* Uses OS features not supported on host */
AUTH = 23 /* Authorization denied */
FORMAT = 24 /* Not used */
RANGE = 25 /* 2nd parameter to sqlite3_bind out of range */
NOTADB = 26 /* File opened that is not a database file */
NOTICE = 27 /* Notifications from sqlite3_log() */
WARNING = 28 /* Warnings from sqlite3_log() */
ROW = 100 /* sqlite3_step() has another row ready */
DONE = 101 /* sqlite3_step() has finished executing */
ERROR_MISSING_COLLSEQ = ERROR | (1 << 8)
ERROR_RETRY = ERROR | (2 << 8)
ERROR_SNAPSHOT = ERROR | (3 << 8)
IOERR_READ = IOERR | (1 << 8)
IOERR_SHORT_READ = IOERR | (2 << 8)
IOERR_WRITE = IOERR | (3 << 8)
IOERR_FSYNC = IOERR | (4 << 8)
IOERR_DIR_FSYNC = IOERR | (5 << 8)
IOERR_TRUNCATE = IOERR | (6 << 8)
IOERR_FSTAT = IOERR | (7 << 8)
IOERR_UNLOCK = IOERR | (8 << 8)
IOERR_RDLOCK = IOERR | (9 << 8)
IOERR_DELETE = IOERR | (10 << 8)
IOERR_BLOCKED = IOERR | (11 << 8)
IOERR_NOMEM = IOERR | (12 << 8)
IOERR_ACCESS = IOERR | (13 << 8)
IOERR_CHECKRESERVEDLOCK = IOERR | (14 << 8)
IOERR_LOCK = IOERR | (15 << 8)
IOERR_CLOSE = IOERR | (16 << 8)
IOERR_DIR_CLOSE = IOERR | (17 << 8)
IOERR_SHMOPEN = IOERR | (18 << 8)
IOERR_SHMSIZE = IOERR | (19 << 8)
IOERR_SHMLOCK = IOERR | (20 << 8)
IOERR_SHMMAP = IOERR | (21 << 8)
IOERR_SEEK = IOERR | (22 << 8)
IOERR_DELETE_NOENT = IOERR | (23 << 8)
IOERR_MMAP = IOERR | (24 << 8)
IOERR_GETTEMPPATH = IOERR | (25 << 8)
IOERR_CONVPATH = IOERR | (26 << 8)
IOERR_VNODE = IOERR | (27 << 8)
IOERR_AUTH = IOERR | (28 << 8)
IOERR_BEGIN_ATOMIC = IOERR | (29 << 8)
IOERR_COMMIT_ATOMIC = IOERR | (30 << 8)
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
IOERR_DATA = IOERR | (32 << 8)
IOERR_CORRUPTFS = IOERR | (33 << 8)
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
LOCKED_VTAB = LOCKED | (2 << 8)
BUSY_RECOVERY = BUSY | (1 << 8)
BUSY_SNAPSHOT = BUSY | (2 << 8)
BUSY_TIMEOUT = BUSY | (3 << 8)
CANTOPEN_NOTEMPDIR = CANTOPEN | (1 << 8)
CANTOPEN_ISDIR = CANTOPEN | (2 << 8)
CANTOPEN_FULLPATH = CANTOPEN | (3 << 8)
CANTOPEN_CONVPATH = CANTOPEN | (4 << 8)
CANTOPEN_DIRTYWAL = CANTOPEN | (5 << 8) /* Not Used */
CANTOPEN_SYMLINK = CANTOPEN | (6 << 8)
CORRUPT_VTAB = CORRUPT | (1 << 8)
CORRUPT_SEQUENCE = CORRUPT | (2 << 8)
CORRUPT_INDEX = CORRUPT | (3 << 8)
READONLY_RECOVERY = READONLY | (1 << 8)
READONLY_CANTLOCK = READONLY | (2 << 8)
READONLY_ROLLBACK = READONLY | (3 << 8)
READONLY_DBMOVED = READONLY | (4 << 8)
READONLY_CANTINIT = READONLY | (5 << 8)
READONLY_DIRECTORY = READONLY | (6 << 8)
ABORT_ROLLBACK = ABORT | (2 << 8)
CONSTRAINT_CHECK = CONSTRAINT | (1 << 8)
CONSTRAINT_COMMITHOOK = CONSTRAINT | (2 << 8)
CONSTRAINT_FOREIGNKEY = CONSTRAINT | (3 << 8)
CONSTRAINT_FUNCTION = CONSTRAINT | (4 << 8)
CONSTRAINT_NOTNULL = CONSTRAINT | (5 << 8)
CONSTRAINT_PRIMARYKEY = CONSTRAINT | (6 << 8)
CONSTRAINT_TRIGGER = CONSTRAINT | (7 << 8)
CONSTRAINT_UNIQUE = CONSTRAINT | (8 << 8)
CONSTRAINT_VTAB = CONSTRAINT | (9 << 8)
CONSTRAINT_ROWID = CONSTRAINT | (10 << 8)
CONSTRAINT_PINNED = CONSTRAINT | (11 << 8)
CONSTRAINT_DATATYPE = CONSTRAINT | (12 << 8)
NOTICE_RECOVER_WAL = NOTICE | (1 << 8)
NOTICE_RECOVER_ROLLBACK = NOTICE | (2 << 8)
NOTICE_RBU = NOTICE | (3 << 8)
WARNING_AUTOINDEX = WARNING | (1 << 8)
AUTH_USER = AUTH | (1 << 8)
OK_LOAD_PERMANENTLY = OK | (1 << 8)
OK_SYMLINK = OK | (2 << 8) /* internal use only */
)

115
internal/util/error.go Normal file
View File

@@ -0,0 +1,115 @@
package util
import (
"fmt"
"runtime"
"strconv"
)
type ErrorString string
func (e ErrorString) Error() string { return string(e) }
const (
NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference")
OOMErr = ErrorString("sqlite3: out of memory")
RangeErr = ErrorString("sqlite3: index out of range")
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
NoGlobalErr = ErrorString("sqlite3: could not find global: ")
NoFuncErr = ErrorString("sqlite3: could not find function: ")
BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
TimeErr = ErrorString("sqlite3: invalid time value")
WhenceErr = ErrorString("sqlite3: invalid whence")
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
)
func AssertErr() ErrorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return ErrorString(msg)
}
func Finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(ErrorString(msg)) }
}
func ErrorCodeString(rc uint32) string {
switch rc {
case ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case ROW:
return "sqlite3: another row available"
case DONE:
return "sqlite3: no more rows available"
}
switch rc & 0xff {
case OK:
return "sqlite3: not an error"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
}

102
internal/util/func.go Normal file
View File

@@ -0,0 +1,102 @@
package util
import (
"context"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
type i32 interface{ ~int32 | ~uint32 }
type i64 interface{ ~int64 | ~uint64 }
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
}
func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcII[TR, T0](fn),
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}
func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIII[TR, T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
}
func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIII[TR, T0, T1, T2](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}
func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIII[TR, T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
}
func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIII[TR, T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}
func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIJ[TR, T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}
func ExportFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIJ[TR, T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}

110
internal/util/mem.go Normal file
View File

@@ -0,0 +1,110 @@
package util
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
func View(mod api.Module, ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(NilErr)
}
if size > math.MaxUint32 {
panic(RangeErr)
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)
}
return buf
}
func ReadUint32(mod api.Module, ptr uint32) uint32 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint32(mod api.Module, ptr uint32, v uint32) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadUint64(mod api.Module, ptr uint32) uint64 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint64(mod api.Module, ptr uint32, v uint64) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadFloat64(mod api.Module, ptr uint32) float64 {
return math.Float64frombits(ReadUint64(mod, ptr))
}
func WriteFloat64(mod api.Module, ptr uint32, v float64) {
WriteUint64(mod, ptr, math.Float64bits(v))
}
func ReadString(mod api.Module, ptr, maxlen uint32) string {
if ptr == 0 {
panic(NilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
return string(buf[:i])
}
}
func WriteBytes(mod api.Module, ptr uint32, b []byte) {
buf := View(mod, ptr, uint64(len(b)))
copy(buf, b)
}
func WriteString(mod api.Module, ptr uint32, s string) {
buf := View(mod, ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

92
internal/util/mem_test.go Normal file
View File

@@ -0,0 +1,92 @@
package util
import (
"math"
"testing"
"github.com/tetratelabs/wazero/experimental/wazerotest"
)
func TestView_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
View(mock, 0, 8)
t.Error("want panic")
}
func TestView_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
View(mock, wazerotest.PageSize-2, 8)
t.Error("want panic")
}
func TestView_overflow(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
View(mock, 1, math.MaxInt64)
t.Error("want panic")
}
func TestReadUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint32(mock, 0)
t.Error("want panic")
}
func TestReadUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint32(mock, wazerotest.PageSize-2)
t.Error("want panic")
}
func TestReadUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint64(mock, 0)
t.Error("want panic")
}
func TestReadUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint64(mock, wazerotest.PageSize-2)
t.Error("want panic")
}
func TestWriteUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint32(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint32(mock, wazerotest.PageSize-2, 1)
t.Error("want panic")
}
func TestWriteUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint64(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint64(mock, wazerotest.PageSize-2, 1)
t.Error("want panic")
}
func TestReadString_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadString(mock, wazerotest.PageSize+2, math.MaxUint32)
t.Error("want panic")
}

114
mem.go
View File

@@ -1,114 +0,0 @@
package sqlite3
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
type memory struct {
mod api.Module
}
func (m memory) view(ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(nilErr)
}
if size > math.MaxUint32 {
panic(rangeErr)
}
buf, ok := m.mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(rangeErr)
}
return buf
}
func (m memory) readUint32(ptr uint32) uint32 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint32(ptr, v uint32) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readUint64(ptr uint32) uint64 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint64(ptr uint32, v uint64) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readFloat64(ptr uint32) float64 {
return math.Float64frombits(m.readUint64(ptr))
}
func (m memory) writeFloat64(ptr uint32, v float64) {
m.writeUint64(ptr, math.Float64bits(v))
}
func (m memory) readString(ptr, maxlen uint32) string {
if ptr == 0 {
panic(nilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := m.mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(rangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(noNulErr)
} else {
return string(buf[:i])
}
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint64(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
buf := m.view(ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -1,90 +0,0 @@
package sqlite3
import (
"math"
"testing"
)
func Test_memory_view_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(0, 8)
t.Error("want panic")
}
func Test_memory_view_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(126, 8)
t.Error("want panic")
}
func Test_memory_view_overflow(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(1, math.MaxInt64)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(0)
t.Error("want panic")
}
func Test_memory_readUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(126)
t.Error("want panic")
}
func Test_memory_readUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(0)
t.Error("want panic")
}
func Test_memory_readUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(126)
t.Error("want panic")
}
func Test_memory_writeUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(126, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(126, 1)
t.Error("want panic")
}
func Test_memory_readString_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readString(130, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -1,161 +0,0 @@
package sqlite3
import (
"context"
"encoding/binary"
"math"
"github.com/tetratelabs/wazero/api"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func newMemory(size uint32) memory {
mem := make(mockMemory, size)
return memory{mockModule{&mem}}
}
type mockModule struct {
memory api.Memory
}
func (m mockModule) Memory() api.Memory { return m.memory }
func (m mockModule) String() string { return "mockModule" }
func (m mockModule) Name() string { return "mockModule" }
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
func (m mockModule) Close(context.Context) error { return nil }
type mockMemory []byte
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
if offset >= m.Size() {
return 0, false
}
return m[offset], true
}
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
if !m.hasSize(offset, 2) {
return 0, false
}
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
}
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
if !m.hasSize(offset, 4) {
return 0, false
}
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
}
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
v, ok := m.ReadUint32Le(offset)
if !ok {
return 0, false
}
return math.Float32frombits(v), true
}
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
if !m.hasSize(offset, 8) {
return 0, false
}
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
}
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
v, ok := m.ReadUint64Le(offset)
if !ok {
return 0, false
}
return math.Float64frombits(v), true
}
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
if !m.hasSize(offset, byteCount) {
return nil, false
}
return m[offset : offset+byteCount : offset+byteCount], true
}
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
if offset >= m.Size() {
return false
}
m[offset] = v
return true
}
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
if !m.hasSize(offset, 2) {
return false
}
binary.LittleEndian.PutUint16(m[offset:], v)
return true
}
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
if !m.hasSize(offset, 4) {
return false
}
binary.LittleEndian.PutUint32(m[offset:], v)
return true
}
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
return m.WriteUint32Le(offset, math.Float32bits(v))
}
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
if !m.hasSize(offset, 8) {
return false
}
binary.LittleEndian.PutUint64(m[offset:], v)
return true
}
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
return m.WriteUint64Le(offset, math.Float64bits(v))
}
func (m mockMemory) Write(offset uint32, val []byte) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
return true
}
func (m mockMemory) WriteString(offset uint32, val string) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
return true
}
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
prev := (len(*m) + 65535) / 65536
*m = append(*m, make([]byte, 65536*delta)...)
return uint32(prev), true
}
func (m mockMemory) PageSize() (result uint32) {
return uint32(len(m) / 65536)
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
}

125
module.go
View File

@@ -3,15 +3,13 @@ package sqlite3
import (
"context"
"crypto/rand"
"io"
"math"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
@@ -25,14 +23,14 @@ import (
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
Debug bool // Whether to enable SQLite debug stack traces.
)
var sqlite3 struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
@@ -43,12 +41,7 @@ func instantiateModule() (*module, error) {
return nil, sqlite3.err
}
name := "sqlite3-" + strconv.FormatUint(sqlite3.instances.Add(1), 10)
cfg := wazero.NewModuleConfig().WithName(name).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
@@ -59,8 +52,13 @@ func instantiateModule() (*module, error) {
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, sqlite3.runtime)
sqlite3.runtime = wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfig().WithDebugInfoEnabled(Debug))
env := vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
@@ -70,7 +68,7 @@ func compileModule() {
}
}
if bin == nil {
sqlite3.err = binaryErr
sqlite3.err = util.BinaryErr
return
}
@@ -79,38 +77,39 @@ func compileModule() {
type module struct {
ctx context.Context
mem memory
api sqliteAPI
mod api.Module
vfs io.Closer
api sqliteAPI
arg [8]uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mem = memory{mod}
m.ctx, m.vfs = vfsContext(context.Background())
m = new(module)
m.mod = mod
m.ctx, m.vfs = vfs.NewContext(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
err = util.NoFuncErr + util.ErrorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := mod.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
g := mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return m.mem.readUint32(uint32(global.Get()))
return util.ReadUint32(mod, uint32(g.Get()))
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
destructor: getVal("malloc_destructor"),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
@@ -141,9 +140,6 @@ func newModule(mod api.Module) (m *module, err error) {
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
@@ -155,7 +151,9 @@ func newModule(mod api.Module) (m *module, err error) {
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
interrupt: getVal("sqlite3_interrupt_offset"),
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
}
if err != nil {
return nil, err
@@ -164,7 +162,7 @@ func newModule(mod api.Module) (m *module, err error) {
}
func (m *module) close() error {
err := m.mem.mod.Close(m.ctx)
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
}
@@ -177,25 +175,20 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
panic(util.OOMErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = m.mem.readString(uint32(r[0]), _MAX_STRING)
if r := m.call(m.api.errstr, rc); r != 0 {
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING)
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
r = m.call(m.api.erroff, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
@@ -206,14 +199,15 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(m.ctx, params...)
func (m *module) call(fn api.Function, params ...uint64) uint64 {
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return r
return m.arg[0]
}
func (m *module) free(ptr uint32) {
@@ -225,12 +219,11 @@ func (m *module) free(ptr uint32) {
func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
panic(util.OOMErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
ptr := uint32(m.call(m.api.malloc, size))
if ptr == 0 && size != 0 {
panic(oomErr)
panic(util.OOMErr)
}
return ptr
}
@@ -240,13 +233,13 @@ func (m *module) newBytes(b []byte) uint32 {
return 0
}
ptr := m.new(uint64(len(b)))
m.mem.writeBytes(ptr, b)
util.WriteBytes(m.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
m.mem.writeString(ptr, s)
util.WriteString(m.mod, ptr, s)
return ptr
}
@@ -260,10 +253,10 @@ func (m *module) newArena(size uint64) arena {
type arena struct {
m *module
ptrs []uint32
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
@@ -294,16 +287,24 @@ func (a *arena) new(size uint64) uint32 {
return ptr
}
func (a *arena) bytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.m.mem.writeString(ptr, s)
util.WriteString(a.m.mod, ptr, s)
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
@@ -334,9 +335,6 @@ type sqliteAPI struct {
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
@@ -348,5 +346,8 @@ type sqliteAPI struct {
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
interrupt uint32
changes api.Function
lastRowid api.Function
autocommit api.Function
destructor uint32
}

View File

@@ -4,8 +4,14 @@ import (
"bytes"
"math"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
@@ -20,14 +26,14 @@ func TestConn_error_OOM(t *testing.T) {
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
func TestConn_call_closed(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer m.close()
m.close()
defer func() { _ = recover() }()
m.call(m.api.free)
@@ -43,14 +49,18 @@ func TestConn_new(t *testing.T) {
}
defer m.close()
testOOM := func(size uint64) {
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(size)
m.new(math.MaxUint32)
t.Error("want panic")
}
})
testOOM(math.MaxUint32)
testOOM(_MAX_ALLOCATION_SIZE)
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
@@ -66,12 +76,11 @@ func TestConn_newArena(t *testing.T) {
defer arena.free()
const title = "Lorem ipsum"
ptr := arena.string(title)
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := m.mem.readString(ptr, math.MaxUint32); got != title {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -80,9 +89,22 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := m.mem.readString(ptr, math.MaxUint32); got != body {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
ptr = arena.bytes(nil)
if ptr != 0 {
t.Errorf("want nullptr")
}
ptr = arena.bytes([]byte(title))
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
t.Errorf("got %q, want %q", got, title)
}
arena.free()
}
@@ -107,7 +129,7 @@ func TestConn_newBytes(t *testing.T) {
}
want := buf
if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -133,7 +155,7 @@ func TestConn_newString(t *testing.T) {
}
want := str + "\000"
if got := m.mem.view(ptr, uint64(len(want))); string(got) != want {
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -159,22 +181,22 @@ func TestConn_getString(t *testing.T) {
}
want := "sqlite3"
if got := m.mem.readString(ptr, math.MaxUint32); got != want {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := m.mem.readString(ptr, 0); got != "" {
if got := util.ReadString(m.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
m.mem.readString(ptr, uint32(len(want)/2))
util.ReadString(m.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
m.mem.readString(0, math.MaxUint32)
util.ReadString(m.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}

View File

@@ -1,9 +0,0 @@
#include <stddef.h>
#include "main.c"
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);

20
sqlite3/deserialize.patch Normal file
View File

@@ -0,0 +1,20 @@
--- sqlite3.c.orig
+++ sqlite3.c
@@ -60425,7 +60425,7 @@
int rc = SQLITE_OK; /* Return code */
int tempFile = 0; /* True for temp files (incl. in-memory files) */
int memDb = 0; /* True if this is an in-memory file */
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
int memJM = 0; /* Memory journal mode */
#else
# define memJM 0
@@ -60628,7 +60628,7 @@
int fout = 0; /* VFS flags returned by xOpen() */
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
assert( !memDb );
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
#endif
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;

View File

@@ -1,13 +1,34 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "sqlite3.c" ]; then
url="https://sqlite.org/2023/sqlite-amalgamation-3410000.zip"
curl "$url" > sqlite.zip
unzip -d . sqlite.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
rm sqlite.zip
fi
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
patch < vfs_find.patch
patch < deserialize.patch
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
cd ~-

1
sqlite3/ext/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.c

View File

@@ -1,8 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
clang-format -i \
main.c \
os.c \
qsort.c \
amalg.c
shopt -s extglob
clang-format -i !(sqlite3*).@(c|h)

View File

@@ -1,14 +1,28 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.h"
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS
#include "vfs.c"
// Extensions
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
#include "time.c"
int main() {
int rc = sqlite3_initialize();
if (rc != SQLITE_OK) return 1;
}
sqlite3_vfs *os_vfs();
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
__attribute__((constructor)) void init() {
sqlite3_initialize();
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
sqlite3_auto_extension((void (*)(void))sqlite3_decimal_init);
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}

View File

@@ -1,92 +0,0 @@
#include <time.h>
#include "sqlite3.h"
int os_localtime(sqlite3_int64, struct tm *);
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
int os_sleep(sqlite3_vfs *, int microseconds);
int os_current_time(sqlite3_vfs *, double *);
int os_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int os_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int os_delete(sqlite3_vfs *, const char *zName, int syncDir);
int os_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct os_file {
sqlite3_file base;
int id;
int lock;
};
int os_close(sqlite3_file *);
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int os_truncate(sqlite3_file *, sqlite3_int64 size);
int os_sync(sqlite3_file *, int flags);
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int os_file_control(sqlite3_file *pFile, int op, void *pArg);
int os_lock(sqlite3_file *pFile, int eLock);
int os_unlock(sqlite3_file *pFile, int eLock);
int os_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime((sqlite3_int64)*pTime, pTm);
}
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = os_close,
.xRead = os_read,
.xWrite = os_write,
.xTruncate = os_truncate,
.xSync = os_sync,
.xFileSize = os_file_size,
.xLock = os_lock,
.xUnlock = os_unlock,
.xCheckReservedLock = os_check_reserved_lock,
.xFileControl = no_file_control,
.xDeviceCharacteristics = no_device_characteristics,
};
int rc = os_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &os_io : NULL;
return rc;
}
sqlite3_vfs *os_vfs() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct os_file),
.mxPathname = 512,
.zName = "os",
.xOpen = os_open_w,
.xDelete = os_delete,
.xAccess = os_access,
.xFullPathname = os_full_pathname,
.xRandomness = os_randomness,
.xSleep = os_sleep,
.xCurrentTime = os_current_time,
.xCurrentTimeInt64 = os_current_time_64,
};
return &os_vfs;
}

View File

@@ -1,14 +0,0 @@
#include <stddef.h>
void qsort_r(void *, size_t, size_t,
int (*)(const void *, const void *, void *), void *);
typedef int (*cmpfun)(const void *, const void *);
static int wrapper_cmp(const void *v1, const void *v2, void *cmp) {
return ((cmpfun)cmp)(v1, v2);
}
void qsort(void *base, size_t nel, size_t width, cmpfun cmp) {
qsort_r(base, nel, width, wrapper_cmp, cmp);
}

View File

@@ -28,8 +28,14 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
#define SQLITE_ENABLE_ATOMIC_WRITE
#define SQLITE_OMIT_DESERIALIZE
// Because WASM does not support shared memory,
// SQLite disables it for WASM builds.
// SQLite disables WAL for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
@@ -37,7 +43,7 @@
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Recommended Extensions
// Amalgamated Extensions
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
#define SQLITE_ENABLE_JSON1 1
@@ -48,15 +54,9 @@
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
// Snapshot
// #define SQLITE_ENABLE_SNAPSHOT 1
// Session Extension
// #define SQLITE_ENABLE_SESSION 1
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
// Resumable Bulk Update Extension
// #define SQLITE_ENABLE_RBU 1
// Implemented in Go.
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

32
sqlite3/time.c Normal file
View File

@@ -0,0 +1,32 @@
#include <string.h>
#include "sqlite3.h"
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
const void *pKey2) {
// Remove a Z suffix if one key is no longer than the other.
// A Z suffix collates before any character but after the empty string.
// This avoids making different keys equal.
const int nK1 = nKey1;
const int nK2 = nKey2;
const char *pK1 = (const char *)pKey1;
const char *pK2 = (const char *)pKey2;
if (nK1 && nK1 <= nK2 && pK1[nK1 - 1] == 'Z') {
nKey1--;
}
if (nK2 && nK2 <= nK1 && pK2[nK2 - 1] == 'Z') {
nKey2--;
}
int n = nKey1 < nKey2 ? nKey1 : nKey2;
int rc = memcmp(pKey1, pKey2, n);
if (rc == 0) {
rc = nKey1 - nKey2;
}
return rc;
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
}

137
sqlite3/vfs.c Normal file
View File

@@ -0,0 +1,137 @@
#include <time.h>
#include "sqlite3.h"
int go_localtime(struct tm *, sqlite3_int64);
int go_vfs_find(const char *zVfsName);
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
int go_sleep(sqlite3_vfs *, int microseconds);
int go_current_time(sqlite3_vfs *, double *);
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
int *pOutFlags);
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
int go_close(sqlite3_file *);
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int go_truncate(sqlite3_file *, sqlite3_int64 size);
int go_sync(sqlite3_file *, int flags);
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int go_file_control(sqlite3_file *, int op, void *pArg);
int go_sector_size(sqlite3_file *file);
int go_device_characteristics(sqlite3_file *file);
int go_lock(sqlite3_file *, int eLock);
int go_unlock(sqlite3_file *, int eLock);
int go_check_reserved_lock(sqlite3_file *, int *pResOut);
static int go_open_wrapper(sqlite3_vfs *vfs, sqlite3_filename zName,
sqlite3_file *file, int flags, int *pOutFlags) {
static const sqlite3_io_methods os_io = {
.iVersion = 1,
.xClose = go_close,
.xRead = go_read,
.xWrite = go_write,
.xTruncate = go_truncate,
.xSync = go_sync,
.xFileSize = go_file_size,
.xLock = go_lock,
.xUnlock = go_unlock,
.xCheckReservedLock = go_check_reserved_lock,
.xFileControl = go_file_control,
.xSectorSize = go_sector_size,
.xDeviceCharacteristics = go_device_characteristics,
};
memset(file, 0, vfs->szOsFile);
int rc = go_open(vfs, zName, file, flags, pOutFlags);
if (rc) {
return rc;
}
file->pMethods = &os_io;
return SQLITE_OK;
}
struct go_file {
sqlite3_file base;
int handle;
};
int sqlite3_os_init() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = "os",
.xOpen = go_open_wrapper,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&os_vfs, /*default=*/true);
}
sqlite3_destructor_type malloc_destructor = &free;
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime(pTm, (sqlite3_int64)*pTime);
}
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
if (zVfsName) {
static sqlite3_vfs *go_vfs_list;
sqlite3_vfs *found = NULL;
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
sqlite3_vfs *it = *next;
if (go_vfs_find(it->zName)) {
if (!strcmp(zVfsName, it->zName)) found = it;
next = &it->pNext;
} else {
*next = it->pNext;
free(it);
}
}
if (found) {
return found;
}
if (go_vfs_find(zVfsName)) {
sqlite3_vfs *prev = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
*go_vfs_list = (sqlite3_vfs){
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = name,
.pNext = prev,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return go_vfs_list;
}
}
return sqlite3_vfs_find_orig(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 280, "Unexpected offset");

11
sqlite3/vfs_find.patch Normal file
View File

@@ -0,0 +1,11 @@
--- sqlite3.c.orig
+++ sqlite3.c
@@ -25394,7 +25394,7 @@
** Locate a VFS by name. If no name is given, simply return the
** first VFS on the list.
*/
-SQLITE_API sqlite3_vfs *sqlite3_vfs_find(const char *zVfs){
+SQLITE_API sqlite3_vfs *sqlite3_vfs_find_orig(const char *zVfs){
sqlite3_vfs *pVfs = 0;
#if SQLITE_THREADSAFE
sqlite3_mutex *mutex;

135
stmt.go
View File

@@ -3,6 +3,8 @@ package sqlite3
import (
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Stmt is a prepared statement object.
@@ -10,8 +12,8 @@ import (
// https://www.sqlite.org/c3ref/stmt.html
type Stmt struct {
c *Conn
handle uint32
err error
handle uint32
}
// Close destroys the prepared statement object.
@@ -27,7 +29,7 @@ func (s *Stmt) Close() error {
r := s.c.call(s.c.api.finalize, uint64(s.handle))
s.handle = 0
return s.c.error(r[0])
return s.c.error(r)
}
// Reset resets the prepared statement object.
@@ -36,7 +38,7 @@ func (s *Stmt) Close() error {
func (s *Stmt) Reset() error {
r := s.c.call(s.c.api.reset, uint64(s.handle))
s.err = nil
return s.c.error(r[0])
return s.c.error(r)
}
// ClearBindings resets all bindings on the prepared statement.
@@ -44,7 +46,7 @@ func (s *Stmt) Reset() error {
// https://www.sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r[0])
return s.c.error(r)
}
// Step evaluates the SQL statement.
@@ -59,13 +61,13 @@ func (s *Stmt) ClearBindings() error {
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r[0] == _ROW {
if r == _ROW {
return true
}
if r[0] == _DONE {
if r == _DONE {
s.err = nil
} else {
s.err = s.c.error(r[0])
s.err = s.c.error(r)
}
return false
}
@@ -92,7 +94,7 @@ func (s *Stmt) Exec() error {
func (s *Stmt) BindCount() int {
r := s.c.call(s.c.api.bindCount,
uint64(s.handle))
return int(r[0])
return int(r)
}
// BindIndex returns the index of a parameter in the prepared statement
@@ -104,7 +106,7 @@ func (s *Stmt) BindIndex(name string) int {
namePtr := s.c.arena.string(name)
r := s.c.call(s.c.api.bindIndex,
uint64(s.handle), uint64(namePtr))
return int(r[0])
return int(r)
}
// BindName returns the name of a parameter in the prepared statement.
@@ -115,11 +117,11 @@ func (s *Stmt) BindName(param int) string {
r := s.c.call(s.c.api.bindName,
uint64(s.handle), uint64(param))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// BindBool binds a bool to the prepared statement.
@@ -150,7 +152,7 @@ func (s *Stmt) BindInt(param int, value int) error {
func (s *Stmt) BindInt64(param int, value int64) error {
r := s.c.call(s.c.api.bindInteger,
uint64(s.handle), uint64(param), uint64(value))
return s.c.error(r[0])
return s.c.error(r)
}
// BindFloat binds a float64 to the prepared statement.
@@ -160,7 +162,7 @@ func (s *Stmt) BindInt64(param int, value int64) error {
func (s *Stmt) BindFloat(param int, value float64) error {
r := s.c.call(s.c.api.bindFloat,
uint64(s.handle), uint64(param), math.Float64bits(value))
return s.c.error(r[0])
return s.c.error(r)
}
// BindText binds a string to the prepared statement.
@@ -172,8 +174,8 @@ func (s *Stmt) BindText(param int, value string) error {
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
return s.c.error(r[0])
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r)
}
// BindBlob binds a []byte to the prepared statement.
@@ -186,8 +188,8 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
return s.c.error(r[0])
uint64(s.c.api.destructor))
return s.c.error(r)
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
@@ -197,7 +199,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r := s.c.call(s.c.api.bindZeroBlob,
uint64(s.handle), uint64(param), uint64(n))
return s.c.error(r[0])
return s.c.error(r)
}
// BindNull binds a NULL to the prepared statement.
@@ -207,7 +209,7 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
func (s *Stmt) BindNull(param int) error {
r := s.c.call(s.c.api.bindNull,
uint64(s.handle), uint64(param))
return s.c.error(r[0])
return s.c.error(r)
}
// BindTime binds a [time.Time] to the prepared statement.
@@ -215,6 +217,9 @@ func (s *Stmt) BindNull(param int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
if format == TimeFormatDefault {
return s.bindRFC3339Nano(param, value)
}
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
@@ -223,18 +228,32 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
case float64:
s.BindFloat(param, v)
default:
panic(assertErr())
panic(util.AssertErr())
}
return nil
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r := s.c.call(s.c.api.columnCount,
uint64(s.handle))
return int(r[0])
return int(r)
}
// ColumnName returns the name of the result column.
@@ -245,11 +264,11 @@ func (s *Stmt) ColumnName(col int) string {
r := s.c.call(s.c.api.columnName,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
panic(oomErr)
panic(util.OOMErr)
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// ColumnType returns the initial [Datatype] of the result column.
@@ -259,7 +278,7 @@ func (s *Stmt) ColumnName(col int) string {
func (s *Stmt) ColumnType(col int) Datatype {
r := s.c.call(s.c.api.columnType,
uint64(s.handle), uint64(col))
return Datatype(r[0])
return Datatype(r)
}
// ColumnBool returns the value of the result column as a bool.
@@ -291,7 +310,7 @@ func (s *Stmt) ColumnInt(col int) int {
func (s *Stmt) ColumnInt64(col int) int64 {
r := s.c.call(s.c.api.columnInteger,
uint64(s.handle), uint64(col))
return int64(r[0])
return int64(r)
}
// ColumnFloat returns the value of the result column as a float64.
@@ -301,7 +320,7 @@ func (s *Stmt) ColumnInt64(col int) int64 {
func (s *Stmt) ColumnFloat(col int) float64 {
r := s.c.call(s.c.api.columnFloat,
uint64(s.handle), uint64(col))
return math.Float64frombits(r[0])
return math.Float64frombits(r)
}
// ColumnTime returns the value of the result column as a [time.Time].
@@ -320,7 +339,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
case NULL:
return time.Time{}
default:
panic(assertErr())
panic(util.AssertErr())
}
t, err := format.Decode(v)
if err != nil {
@@ -334,21 +353,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return ""
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
mem := s.c.mem.view(ptr, r[0])
return string(mem)
return string(s.ColumnRawText(col))
}
// ColumnBlob appends to buf and returns
@@ -357,21 +362,53 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
r := s.c.call(s.c.api.columnBlob,
return append(buf, s.ColumnRawBlob(col)...)
}
// ColumnRawText returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return buf[0:0]
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
mem := s.c.mem.view(ptr, r[0])
return append(buf[0:0], mem...)
return util.View(s.c.mod, ptr, r)
}
// ColumnRawBlob returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
}
// Return true if stmt is an empty SQL statement.

View File

@@ -124,4 +124,46 @@ func TestBackup(t *testing.T) {
t.Fatal(err)
}
}()
func() { // Incremental.
db, err := sqlite3.Open(backupName)
if err != nil {
t.Fatal(err)
}
defer db.Close()
b, err := db.BackupInit("main", ":memory:")
if err != nil {
t.Fatal(err)
}
defer b.Close()
done, err := b.Step(1)
if done {
t.Error("want false")
}
if err != nil {
t.Error(err)
}
n := b.Remaining()
if n != 1 {
t.Errorf("got %d", n)
}
n = b.PageCount()
if n != 2 {
t.Errorf("got %d", n)
}
err = b.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"hash/adler32"
"io"
"testing"
@@ -48,17 +49,17 @@ func TestBlob(t *testing.T) {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:size/2]))
_, err = blob.Write(data[:size/2])
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:]))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
n, err := blob.Write(data[:])
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
_, err = io.Copy(blob, bytes.NewReader(data[size/2:size]))
_, err = blob.Write(data[size/2 : size])
if err != nil {
t.Fatal(err)
}
@@ -87,6 +88,126 @@ func TestBlob(t *testing.T) {
}
}
func TestBlob_large(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1000000))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
size := blob.Size()
if size != 1000000 {
t.Errorf("got %d, want 1000000", size)
}
hash := adler32.New()
_, err = io.CopyN(blob, io.TeeReader(rand.Reader, hash), 1000000)
if err != nil {
t.Fatal(err)
}
_, err = blob.Seek(0, io.SeekStart)
if err != nil {
t.Fatal(err)
}
want := hash.Sum32()
hash.Reset()
_, err = io.Copy(hash, blob)
if err != nil {
t.Fatal(err)
}
if got := hash.Sum32(); got != want {
t.Fatalf("got %d, want %d", got, want)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_overflow(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
n, err := blob.ReadFrom(rand.Reader)
if n != 1024 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
n, err = blob.ReadFrom(rand.Reader)
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
_, err = blob.Seek(-128, io.SeekEnd)
if err != nil {
t.Fatal(err)
}
n, err = blob.WriteTo(io.Discard)
if n != 128 || err != nil {
t.Fatalf("got (%d, %v), want (128, nil)", n, err)
}
n, err = blob.WriteTo(io.Discard)
if n != 0 || err != nil {
t.Fatalf("got (%d, %v), want (0, nil)", n, err)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_invalid(t *testing.T) {
t.Parallel()

View File

@@ -126,7 +126,7 @@ func testTxQuery(t params) {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
t.Fatal("expected one row")
}
var name string

View File

@@ -3,6 +3,9 @@ package tests
import (
"context"
"errors"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
@@ -13,7 +16,7 @@ import (
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
_, err := sqlite3.OpenFlags(".", 0)
if err == nil {
t.Fatal("want error")
}
@@ -22,6 +25,55 @@ func TestConn_Open_dir(t *testing.T) {
}
}
func TestConn_Open_notfound(t *testing.T) {
t.Parallel()
_, err := sqlite3.OpenFlags("test.db", sqlite3.OPEN_READONLY)
if err == nil {
t.Fatal("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func TestConn_Open_modeof(t *testing.T) {
t.Parallel()
dir := t.TempDir()
file := filepath.Join(dir, "test.db")
mode := filepath.Join(dir, "modeof.txt")
fd, err := os.OpenFile(mode, os.O_CREATE, 0624)
if err != nil {
t.Fatal(err)
}
fi, err := fd.Stat()
if err != nil {
t.Fatal(err)
}
fd.Close()
db, err := sqlite3.Open("file:" + file + "?modeof=" + mode)
if err != nil {
t.Fatal(err)
}
di, err := os.Stat(file)
if err != nil {
t.Fatal(err)
}
db.Close()
if di.Mode() != fi.Mode() {
t.Errorf("got %v, want %v", di.Mode(), fi.Mode())
}
_, err = sqlite3.Open("file:" + file + "?modeof=" + mode + "2")
if err == nil {
t.Fatal("want error")
}
}
func TestConn_Close(t *testing.T) {
var conn *sqlite3.Conn
conn.Close()
@@ -58,6 +110,41 @@ func TestConn_Close_BUSY(t *testing.T) {
}
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open("file::memory:?_pragma=busy_timeout(1000)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
got, err := db.Pragma("busy_timeout")
if err != nil {
t.Fatal(err)
}
want := []string{"1000"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
var serr *sqlite3.Error
_, err = db.Pragma("+")
if err == nil {
t.Error("want: error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: near "+": syntax error` {
t.Error("got message:", got)
}
}
func TestConn_SetInterrupt(t *testing.T) {
t.Parallel()
@@ -202,59 +289,3 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Error("got message:", got)
}
}
func TestConn_MustPrepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(``)
t.Error("want panic")
}
func TestConn_MustPrepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT 1; -- HERE`)
t.Error("want panic")
}
func TestConn_MustPrepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.Pragma("encoding=''")
t.Error("want panic")
}

View File

@@ -6,19 +6,24 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestDB_memory(t *testing.T) {
t.Parallel()
testDB(t, ":memory:")
}
func TestDB_file(t *testing.T) {
t.Parallel()
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func testDB(t *testing.T, name string) {
t.Parallel()
func TestDB_vfs(t *testing.T) {
testDB(t, "file:test.db?vfs=memdb")
}
func testDB(t *testing.T, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)

View File

@@ -27,18 +27,32 @@ func TestDriver(t *testing.T) {
}
defer conn.Close()
_, err = conn.ExecContext(ctx,
res, err := conn.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
changes, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if changes != 0 {
t.Errorf("got %d want 0", changes)
}
id, err := res.LastInsertId()
if err != nil {
t.Fatal(err)
}
if id != 0 {
t.Errorf("got %d want 0", changes)
}
res, err := conn.ExecContext(ctx,
res, err = conn.ExecContext(ctx,
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
changes, err := res.RowsAffected()
changes, err = res.RowsAffected()
if err != nil {
t.Fatal(err)
}

77
tests/ext_test.go Normal file
View File

@@ -0,0 +1,77 @@
package tests
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func Test_base64(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// base64
stmt, _, err := db.Prepare(`SELECT base64('TWFueSBoYW5kcyBtYWtlIGxpZ2h0IHdvcmsu')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if got := stmt.ColumnText(0); got != "Many hands make light work." {
t.Errorf("got %q", got)
}
}
func Test_decimal(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT decimal_add(decimal('0.1'), decimal('0.2')) = decimal('0.3')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}
func Test_uint(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT 'z2' < 'z11' COLLATE UINT`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}

View File

@@ -1,26 +0,0 @@
#!/usr/bin/env bash
set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "mptest.c" ]; then
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/mptest.c"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/config01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/config02.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/crash01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/crash02.subtest"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/multiwrite01.test"
fi
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o mptest.wasm main.c test.c \
-I../../../sqlite3 \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid

View File

@@ -1,22 +0,0 @@
#include <stdbool.h>
#include <stddef.h>
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
void __attribute__((constructor)) premain() { sqlite3_initialize(); }
int sqlite3_enable_load_extension(sqlite3 *db, int onoff) { return SQLITE_OK; }
void *sqlite3_trace(sqlite3 *db, void (*xTrace)(void *, const char *),
void *pArg) {
return NULL;
}
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3960b873a7dab969a66f7859d491cec0dd4e6c0c9f83eab449fb15ec5ebdfd8f
size 1077281

View File

@@ -1,5 +0,0 @@
#define unlink dont_unlink
#include "mptest.c"
int dont_unlink(const char *pathname) { return 0; }

View File

@@ -11,6 +11,7 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestParallel(t *testing.T) {
@@ -21,7 +22,29 @@ func TestParallel(t *testing.T) {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
name := "file:" +
filepath.Join(t.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
testIntegrity(t, name)
}
func TestMemory(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -31,8 +54,14 @@ func TestMultiProcess(t *testing.T) {
t.Skip("skipping in short mode")
}
name := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbname", name)
file := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbfile", file)
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
out, err := cmd.StdoutPipe()
@@ -57,11 +86,17 @@ func TestMultiProcess(t *testing.T) {
}
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
file := os.Getenv("TestMultiProcess_dbfile")
if file == "" || testing.Short() {
t.SkipNow()
}
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, 1000)
}
@@ -73,16 +108,6 @@ func testParallel(t *testing.T, name string, n int) {
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA synchronous=off;
PRAGMA locking_mode=normal;
PRAGMA journal_mode=truncate;
`)
if err != nil {
return err
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err

View File

@@ -40,7 +40,7 @@ func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
@@ -89,14 +89,14 @@ func TestTimeFormat_Decode(t *testing.T) {
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", reftime, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
@@ -118,3 +118,52 @@ func TestTimeFormat_Decode(t *testing.T) {
})
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
stmt.Step()
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
if err != nil {
t.Fatal(err)
}
var t0 time.Time
for stmt.Step() {
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
if t0.After(t1) {
t.Errorf("got %v after %v", t0, t1)
}
t0 = t1
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -185,10 +185,10 @@ func TestConn_Transaction_interrupt(t *testing.T) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
var nilErr error
tx.End(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -210,6 +210,33 @@ func TestConn_Transaction_interrupt(t *testing.T) {
}
}
func TestConn_Transaction_interrupted(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
cancel()
tx := db.Begin()
err = tx.Commit()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
@@ -286,7 +313,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
}
insert := func(succeed bool) (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -344,7 +371,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
}
panics := func() (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -395,12 +422,12 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
@@ -408,19 +435,19 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
release1 := db.Savepoint()
savept1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
release2 := db.Savepoint()
savept2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint()(&err)
db.Savepoint().Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
@@ -431,15 +458,15 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
}
err = context.Canceled
release2(&err)
savept2.Release(&err)
if err != context.Canceled {
t.Fatal(err)
}
var nilErr error
release1(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
savept1.Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -475,7 +502,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
@@ -484,7 +511,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}

36
tests/vfs_test.go Normal file
View File

@@ -0,0 +1,36 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/ncruces/go-sqlite3/vfs/readervfs"
)
func TestMemoryVFS_Open_notfound(t *testing.T) {
memdb.Delete("demo.db")
_, err := sqlite3.Open("file:/demo.db?vfs=memdb&mode=ro")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func TestReaderVFS_Open_notfound(t *testing.T) {
readervfs.Delete("demo.db")
_, err := sqlite3.Open("file:demo.db?vfs=reader")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}

40
time.go
View File

@@ -6,6 +6,7 @@ import (
"strings"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
@@ -62,13 +63,18 @@ const (
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving any timezone offset.
//
// This is the format used by the database/sql driver:
// [database/sql.Row.Scan] is able to decode as [time.Time]
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
// Use [TimeFormat7] for time-ordered encoding.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
@@ -78,6 +84,8 @@ const (
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
@@ -123,9 +131,9 @@ func (f TimeFormat) Encode(t time.Time) any {
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
// are safe to use for current events, from 1980 through at least 2260.
// Unix timestamps before 1980 may be misinterpreted as julian day numbers,
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html
@@ -141,7 +149,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
@@ -160,7 +168,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMilli:
@@ -177,7 +185,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMilli(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMicro:
@@ -194,14 +202,14 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMicro(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
v = i
}
@@ -211,7 +219,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(0, int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
// Special formats
@@ -281,7 +289,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case
@@ -293,7 +301,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
return f.parseRelaxed(s)
@@ -303,7 +311,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
@@ -311,7 +319,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
default:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
if f == "" {
f = time.RFC3339Nano

154
tx.go
View File

@@ -4,10 +4,14 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
)
// Tx is an in-progress database transaction.
//
// https://www.sqlite.org/lang_transaction.html
type Tx struct {
c *Conn
}
@@ -16,8 +20,9 @@ type Tx struct {
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
err := c.Exec(`BEGIN DEFERRED`)
if err != nil && !errors.Is(err, INTERRUPT) {
// BEGIN even if interrupted.
err := c.txExecInterrupted(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Tx{c}
@@ -64,21 +69,22 @@ func (tx Tx) End(errp *error) {
defer panic(recovered)
}
if tx.c.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
@@ -92,33 +98,28 @@ func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rollsback the transaction.
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
// ROLLBACK even if the connection has been interrupted.
old := tx.c.SetInterrupt(context.Background())
defer tx.c.SetInterrupt(old)
return tx.c.Exec(`ROLLBACK`)
return tx.c.txExecInterrupted(`ROLLBACK`)
}
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a release func that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// // ... do work in the transaction
// }
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
@@ -127,52 +128,75 @@ func (c *Conn) Savepoint() (release func(*error)) {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
if errors.Is(err, INTERRUPT) {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return Savepoint{c: c, name: name}
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// savept := conn.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if c.GetAutocommit() {
// There is nothing to commit/rollback.
if *errp == nil && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txExecInterrupted(fmt.Sprintf(`
ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://www.sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
}
func (c *Conn) txExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
}

348
vfs.go
View File

@@ -1,348 +0,0 @@
package sqlite3
import (
"context"
"crypto/rand"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/sys"
)
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
_, err := wasi.Instantiate(ctx)
if err != nil {
panic(err)
}
env := vfsNewEnvModuleBuilder(r)
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder {
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("os_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("os_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("os_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("os_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("os_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("os_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("os_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("os_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("os_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("os_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("os_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("os_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("os_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("os_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("os_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("os_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("os_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("os_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("os_file_control")
return env
}
// Poor man's namespaces.
const (
vfsOS vfsOSMethods = false
vfsFile vfsFileMethods = false
)
type (
vfsOSMethods bool
vfsFileMethods bool
)
type vfsKey struct{}
type vfsState struct {
files []*os.File
}
func vfsContext(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
_ = mod.CloseWithExitCode(ctx, exitCode)
// Prevent any code from executing after this function.
panic(sys.NewExitError(mod.Name(), exitCode))
}
func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uint32 {
tm := time.Unix(int64(t), 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
mem := memory{mod}
mem.writeUint32(pTm+0*ptrlen, uint32(tm.Second()))
mem.writeUint32(pTm+1*ptrlen, uint32(tm.Minute()))
mem.writeUint32(pTm+2*ptrlen, uint32(tm.Hour()))
mem.writeUint32(pTm+3*ptrlen, uint32(tm.Day()))
mem.writeUint32(pTm+4*ptrlen, uint32(tm.Month()-time.January))
mem.writeUint32(pTm+5*ptrlen, uint32(tm.Year()-1900))
mem.writeUint32(pTm+6*ptrlen, uint32(tm.Weekday()-time.Sunday))
mem.writeUint32(pTm+7*ptrlen, uint32(tm.YearDay()-1))
mem.writeUint32(pTm+8*ptrlen, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.view(zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, pVfs, nMicro uint32) uint32 {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uint32 {
day := julianday.Float(time.Now())
memory{mod}.writeFloat64(prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) uint32 {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
memory{mod}.writeUint64(piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) uint32 {
rel := memory{mod}.readString(zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return uint32(IOERR)
}
// Consider either using [filepath.EvalSymlinks] to canonicalize the path (as the Unix VFS does).
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
return _OK
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) uint32 {
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _OK
}
if err != nil {
return uint32(IOERR_DELETE)
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err == nil {
err = f.Sync()
f.Close()
}
if err != nil {
return uint32(IOERR_DELETE)
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
fi, err := os.Stat(path)
var res uint32
switch {
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
return uint32(IOERR_ACCESS)
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
res = 1
} else {
res = 0
}
case errors.Is(err, fs.ErrPermission):
res = 0
default:
return uint32(IOERR_ACCESS)
}
memory{mod}.writeUint32(pResOut, res)
return _OK
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var file *os.File
if zName == 0 {
file, err = os.CreateTemp("", "*.db")
} else {
name := memory{mod}.readString(zName, _MAX_PATHNAME)
file, err = os.OpenFile(name, oflags, 0600)
}
if err != nil {
return uint32(CANTOPEN)
}
if flags&OPEN_DELETEONCLOSE != 0 {
vfsOS.DeleteOnClose(file)
}
vfsFile.Open(ctx, mod, pFile, file)
if pOutFlags != 0 {
memory{mod}.writeUint32(pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
err := vfsFile.Close(ctx, mod, pFile)
if err != nil {
return uint32(IOERR_CLOSE)
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFile.GetOS(ctx, mod, pFile)
n, err := file.ReadAt(buf, int64(iOfst))
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return uint32(IOERR_READ)
}
for i := range buf[n:] {
buf[n+i] = 0
}
return uint32(IOERR_SHORT_READ)
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFile.GetOS(ctx, mod, pFile)
_, err := file.WriteAt(buf, int64(iOfst))
if err != nil {
return uint32(IOERR_WRITE)
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 {
file := vfsFile.GetOS(ctx, mod, pFile)
err := file.Truncate(int64(nByte))
if err != nil {
return uint32(IOERR_TRUNCATE)
}
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
file := vfsFile.GetOS(ctx, mod, pFile)
err := file.Sync()
if err != nil {
return uint32(IOERR_FSYNC)
}
return _OK
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 {
// This uses [os.File.Seek] because we don't care about the offset for reading/writing.
// But consider using [os.File.Stat] instead (as other VFSes do).
file := vfsFile.GetOS(ctx, mod, pFile)
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return uint32(IOERR_SEEK)
}
memory{mod}.writeUint64(pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
return uint32(NOTFOUND)
}

9
vfs/README.md Normal file
View File

@@ -0,0 +1,9 @@
# Go SQLite VFS API
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
It replaces the default VFS with a pure Go implementation,
that is tested on Linux, macOS and Windows,
but which should also work on illumos and the various BSDs.
It also exposes interfaces that should allow you to implement your own custom VFSes.

103
vfs/api.go Normal file
View File

@@ -0,0 +1,103 @@
// Package vfs wraps the C SQLite VFS API.
package vfs
import "net/url"
// A VFS defines the interface between the SQLite core and the underlying operating system.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.
//
// https://www.sqlite.org/c3ref/vfs.html
type VFS interface {
Open(name string, flags OpenFlag) (File, OpenFlag, error)
Delete(name string, syncDir bool) error
Access(name string, flags AccessFlag) (bool, error)
FullPathname(name string) (string, error)
}
// VFSParams extends VFS to with the ability to handle URI parameters
// through the OpenParams method.
//
// https://www.sqlite.org/c3ref/uri_boolean.html
type VFSParams interface {
VFS
OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error)
}
// A File represents an open file in the OS interface layer.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.
// In particular, sqlite3.BUSY is necessary to correctly implement lock methods.
//
// https://www.sqlite.org/c3ref/io_methods.html
type File interface {
Close() error
ReadAt(p []byte, off int64) (n int, err error)
WriteAt(p []byte, off int64) (n int, err error)
Truncate(size int64) error
Sync(flags SyncFlag) error
Size() (int64, error)
Lock(lock LockLevel) error
Unlock(lock LockLevel) error
CheckReservedLock() (bool, error)
SectorSize() int
DeviceCharacteristics() DeviceCharacteristic
}
// FileLockState extends File to implement the
// SQLITE_FCNTL_LOCKSTATE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileLockState interface {
File
LockState() LockLevel
}
// FileSizeHint extends File to implement the
// SQLITE_FCNTL_SIZE_HINT file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileSizeHint interface {
File
SizeHint(size int64) error
}
// FileHasMoved extends File to implement the
// SQLITE_FCNTL_HAS_MOVED file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileHasMoved interface {
File
HasMoved() (bool, error)
}
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FilePowersafeOverwrite interface {
File
PowersafeOverwrite() bool
SetPowersafeOverwrite(bool)
}
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileCommitPhaseTwo interface {
File
CommitPhaseTwo() error
}
// FileBatchAtomicWrite extends File to implement the
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileBatchAtomicWrite interface {
File
BeginAtomicWrite() error
CommitAtomicWrite() error
RollbackAtomicWrite() error
}

215
vfs/const.go Normal file
View File

@@ -0,0 +1,215 @@
package vfs
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_DEFAULT_SECTOR_SIZE = 4096
)
// https://www.sqlite.org/rescode.html
type _ErrorCode uint32
func (e _ErrorCode) Error() string {
return util.ErrorCodeString(uint32(e))
}
const (
_OK _ErrorCode = util.OK
_PERM _ErrorCode = util.PERM
_BUSY _ErrorCode = util.BUSY
_READONLY _ErrorCode = util.READONLY
_IOERR _ErrorCode = util.IOERR
_NOTFOUND _ErrorCode = util.NOTFOUND
_CANTOPEN _ErrorCode = util.CANTOPEN
_IOERR_READ _ErrorCode = util.IOERR_READ
_IOERR_SHORT_READ _ErrorCode = util.IOERR_SHORT_READ
_IOERR_WRITE _ErrorCode = util.IOERR_WRITE
_IOERR_FSYNC _ErrorCode = util.IOERR_FSYNC
_IOERR_DIR_FSYNC _ErrorCode = util.IOERR_DIR_FSYNC
_IOERR_TRUNCATE _ErrorCode = util.IOERR_TRUNCATE
_IOERR_FSTAT _ErrorCode = util.IOERR_FSTAT
_IOERR_UNLOCK _ErrorCode = util.IOERR_UNLOCK
_IOERR_RDLOCK _ErrorCode = util.IOERR_RDLOCK
_IOERR_DELETE _ErrorCode = util.IOERR_DELETE
_IOERR_ACCESS _ErrorCode = util.IOERR_ACCESS
_IOERR_CHECKRESERVEDLOCK _ErrorCode = util.IOERR_CHECKRESERVEDLOCK
_IOERR_LOCK _ErrorCode = util.IOERR_LOCK
_IOERR_CLOSE _ErrorCode = util.IOERR_CLOSE
_IOERR_SEEK _ErrorCode = util.IOERR_SEEK
_IOERR_DELETE_NOENT _ErrorCode = util.IOERR_DELETE_NOENT
_IOERR_BEGIN_ATOMIC _ErrorCode = util.IOERR_BEGIN_ATOMIC
_IOERR_COMMIT_ATOMIC _ErrorCode = util.IOERR_COMMIT_ATOMIC
_IOERR_ROLLBACK_ATOMIC _ErrorCode = util.IOERR_ROLLBACK_ATOMIC
_CANTOPEN_FULLPATH _ErrorCode = util.CANTOPEN_FULLPATH
_CANTOPEN_ISDIR _ErrorCode = util.CANTOPEN_ISDIR
_OK_SYMLINK _ErrorCode = util.OK_SYMLINK
)
// OpenFlag is a flag for the [VFS] Open method.
//
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
)
// AccessFlag is a flag for the [VFS] Access method.
//
// https://www.sqlite.org/c3ref/c_access_exists.html
type AccessFlag uint32
const (
ACCESS_EXISTS AccessFlag = 0
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
ACCESS_READ AccessFlag = 2 /* Unused */
)
// SyncFlag is a flag for the [File] Sync method.
//
// https://www.sqlite.org/c3ref/c_sync_dataonly.html
type SyncFlag uint32
const (
SYNC_NORMAL SyncFlag = 0x00002
SYNC_FULL SyncFlag = 0x00003
SYNC_DATAONLY SyncFlag = 0x00010
)
// LockLevel is a value used with [File] Lock and Unlock methods.
//
// https://www.sqlite.org/c3ref/c_lock_exclusive.html
type LockLevel uint32
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
LOCK_NONE LockLevel = 0 /* xUnlock() only */
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
LOCK_SHARED LockLevel = 1 /* xLock() or xUnlock() */
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
LOCK_RESERVED LockLevel = 2 /* xLock() only */
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
LOCK_PENDING LockLevel = 3 /* internal use only */
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
LOCK_EXCLUSIVE LockLevel = 4 /* xLock() only */
)
// DeviceCharacteristic is a flag retuned by the [File] DeviceCharacteristics method.
//
// https://www.sqlite.org/c3ref/c_iocap_atomic.html
type DeviceCharacteristic uint32
const (
IOCAP_ATOMIC DeviceCharacteristic = 0x00000001
IOCAP_ATOMIC512 DeviceCharacteristic = 0x00000002
IOCAP_ATOMIC1K DeviceCharacteristic = 0x00000004
IOCAP_ATOMIC2K DeviceCharacteristic = 0x00000008
IOCAP_ATOMIC4K DeviceCharacteristic = 0x00000010
IOCAP_ATOMIC8K DeviceCharacteristic = 0x00000020
IOCAP_ATOMIC16K DeviceCharacteristic = 0x00000040
IOCAP_ATOMIC32K DeviceCharacteristic = 0x00000080
IOCAP_ATOMIC64K DeviceCharacteristic = 0x00000100
IOCAP_SAFE_APPEND DeviceCharacteristic = 0x00000200
IOCAP_SEQUENTIAL DeviceCharacteristic = 0x00000400
IOCAP_UNDELETABLE_WHEN_OPEN DeviceCharacteristic = 0x00000800
IOCAP_POWERSAFE_OVERWRITE DeviceCharacteristic = 0x00001000
IOCAP_IMMUTABLE DeviceCharacteristic = 0x00002000
IOCAP_BATCH_ATOMIC DeviceCharacteristic = 0x00004000
)
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type _FcntlOpcode uint32
const (
_FCNTL_LOCKSTATE _FcntlOpcode = 1
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
_FCNTL_SIZE_HINT _FcntlOpcode = 5
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
_FCNTL_FILE_POINTER _FcntlOpcode = 7
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
_FCNTL_OVERWRITE _FcntlOpcode = 11
_FCNTL_VFSNAME _FcntlOpcode = 12
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
_FCNTL_PRAGMA _FcntlOpcode = 14
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
_FCNTL_TRACE _FcntlOpcode = 19
_FCNTL_HAS_MOVED _FcntlOpcode = 20
_FCNTL_SYNC _FcntlOpcode = 21
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
_FCNTL_ZIPVFS _FcntlOpcode = 25
_FCNTL_RBU _FcntlOpcode = 26
_FCNTL_VFS_POINTER _FcntlOpcode = 27
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
_FCNTL_PDB _FcntlOpcode = 30
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
_FCNTL_DATA_VERSION _FcntlOpcode = 35
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
_FCNTL_CKPT_DONE _FcntlOpcode = 37
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
_FCNTL_CKPT_START _FcntlOpcode = 39
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
_FCNTL_CKSM_FILE _FcntlOpcode = 41
_FCNTL_RESET_CACHE _FcntlOpcode = 42
)

201
vfs/file.go Normal file
View File

@@ -0,0 +1,201 @@
package vfs
import (
"errors"
"io"
"io/fs"
"net/url"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
)
type vfsOS struct{}
func (vfsOS) FullPathname(path string) (string, error) {
path, err := filepath.Abs(path)
if err != nil {
return "", err
}
fi, err := os.Lstat(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return path, nil
}
return "", err
}
if fi.Mode()&fs.ModeSymlink != 0 {
err = _OK_SYMLINK
}
return path, err
}
func (vfsOS) Delete(path string, syncDir bool) error {
err := os.Remove(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
return err
}
if runtime.GOOS != "windows" && syncDir {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return nil
}
func (vfsOS) Access(name string, flags AccessFlag) (bool, error) {
err := osAccess(name, flags)
if flags == ACCESS_EXISTS {
if errors.Is(err, fs.ErrNotExist) {
return false, nil
}
} else {
if errors.Is(err, fs.ErrPermission) {
return false, nil
}
}
return err == nil, err
}
func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
return vfsOS{}.OpenParams(name, flags, nil)
}
func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error) {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var f *os.File
if name == "" {
f, err = os.CreateTemp("", "*.db")
} else {
f, err = osOpenFile(name, oflags, 0666)
}
if err != nil {
if errors.Is(err, syscall.EISDIR) {
return nil, flags, _CANTOPEN_ISDIR
}
return nil, flags, err
}
if modeof := params.Get("modeof"); modeof != "" {
if err = osSetMode(f, modeof); err != nil {
f.Close()
return nil, flags, _IOERR_FSTAT
}
}
if flags&OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
file := vfsFile{
File: f,
psow: true,
readOnly: flags&OPEN_READONLY != 0,
syncDir: runtime.GOOS != "windows" &&
flags&(OPEN_CREATE) != 0 &&
flags&(OPEN_MAIN_JOURNAL|OPEN_SUPER_JOURNAL|OPEN_WAL) != 0,
}
return &file, flags, nil
}
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock LockLevel
psow bool
syncDir bool
readOnly bool
}
var (
// Ensure these interfaces are implemented:
_ FileLockState = &vfsFile{}
_ FileHasMoved = &vfsFile{}
_ FileSizeHint = &vfsFile{}
_ FilePowersafeOverwrite = &vfsFile{}
)
func (f *vfsFile) Sync(flags SyncFlag) error {
dataonly := (flags & SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == SYNC_FULL
err := osSync(f.File, fullsync, dataonly)
if err != nil {
return err
}
if runtime.GOOS != "windows" && f.syncDir {
f.syncDir = false
d, err := os.Open(filepath.Dir(f.File.Name()))
if err != nil {
return nil
}
defer d.Close()
err = osSync(d, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return nil
}
func (f *vfsFile) Size() (int64, error) {
return f.Seek(0, io.SeekEnd)
}
func (*vfsFile) SectorSize() int {
return _DEFAULT_SECTOR_SIZE
}
func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
if f.psow {
return IOCAP_POWERSAFE_OVERWRITE
}
return 0
}
func (f *vfsFile) SizeHint(size int64) error {
return osAllocate(f.File, size)
}
func (f *vfsFile) HasMoved() (bool, error) {
fi, err := f.Stat()
if err != nil {
return false, err
}
pi, err := os.Stat(f.Name())
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return true, nil
}
return false, err
}
return !os.SameFile(fi, pi), nil
}
func (f *vfsFile) LockState() LockLevel { return f.lock }
func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }

150
vfs/lock.go Normal file
View File

@@ -0,0 +1,150 @@
package vfs
import (
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func (f *vfsFile) Lock(lock LockLevel) error {
// Argument check. SQLite never explicitly requests a pending lock.
if lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
switch {
case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
case f.lock == LOCK_NONE && lock > LOCK_SHARED:
// We never move from unlocked to anything higher than a shared lock.
panic(util.AssertErr())
case f.lock != LOCK_SHARED && lock == LOCK_RESERVED:
// A shared lock is always held when a reserved lock is requested.
panic(util.AssertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if f.lock >= lock {
return nil
}
// Do not allow any kind of write-lock on a read-only database.
if f.readOnly && lock >= LOCK_RESERVED {
return _IOERR_LOCK
}
switch lock {
case LOCK_SHARED:
// Must be unlocked to get SHARED.
if f.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
return rc
}
f.lock = LOCK_SHARED
return nil
case LOCK_RESERVED:
// Must be SHARED to get RESERVED.
if f.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
return rc
}
f.lock = LOCK_RESERVED
return nil
case LOCK_EXCLUSIVE:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if f.lock <= LOCK_NONE || f.lock >= LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if f.lock < LOCK_PENDING {
if rc := osGetPendingLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_PENDING
}
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
return rc
}
f.lock = LOCK_EXCLUSIVE
return nil
default:
panic(util.AssertErr())
}
}
func (f *vfsFile) Unlock(lock LockLevel) error {
// Argument check.
if lock != LOCK_NONE && lock != LOCK_SHARED {
panic(util.AssertErr())
}
// Connection state check.
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// If we don't have a more restrictive lock, do nothing.
if f.lock <= lock {
return nil
}
switch lock {
case LOCK_SHARED:
if rc := osDowngradeLock(f.File, f.lock); rc != _OK {
return rc
}
f.lock = LOCK_SHARED
return nil
case LOCK_NONE:
rc := osReleaseLock(f.File, f.lock)
f.lock = LOCK_NONE
return rc
default:
panic(util.AssertErr())
}
}
func (f *vfsFile) CheckReservedLock() (bool, error) {
// Connection state check.
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
if f.lock >= LOCK_RESERVED {
return true, nil
}
return osCheckReservedLock(f.File)
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

220
vfs/lock_test.go Normal file
View File

@@ -0,0 +1,220 @@
package vfs
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/experimental/wazerotest"
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
rc := vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_RESERVED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_RESERVED) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile2, LOCK_EXCLUSIVE)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_EXCLUSIVE) {
t.Error("invalid lock state", got)
}
rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) {
t.Error("invalid lock state", got)
}
rc = vfsUnlock(ctx, mod, pFile2, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}

9
vfs/memdb/README.md Normal file
View File

@@ -0,0 +1,9 @@
# Go `"memdb"` SQLite VFS
This package implements the [`"memdb"`](https://www.sqlite.org/src/file/src/memdb.c)
SQLite VFS in pure Go.
It has some benefits over the C version:
- the memory backing the database needs not be contiguous,
- the database can grow/shrink incrementally without copying,
- reader-writer concurrency is slightly improved.

59
vfs/memdb/api.go Normal file
View File

@@ -0,0 +1,59 @@
// Package memdb implements the "memdb" SQLite VFS.
//
// The "memdb" [vfs.VFS] allows the same in-memory database to be shared
// among multiple database connections in the same process,
// as long as the database name begins with "/".
//
// Importing package memdb registers the VFS.
//
// import _ "github.com/ncruces/go-sqlite3/vfs/memdb"
package memdb
import (
"sync"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("memdb", memVFS{})
}
var (
memoryMtx sync.Mutex
// +checklocks:memoryMtx
memoryDBs = map[string]*memDB{}
)
// Create creates a shared memory database,
// using data as its initial contents.
// The new database takes ownership of data,
// and the caller should not use data after this call.
func Create(name string, data []byte) {
memoryMtx.Lock()
defer memoryMtx.Unlock()
db := new(memDB)
db.size = int64(len(data))
sectors := divRoundUp(db.size, sectorSize)
db.data = make([]*[sectorSize]byte, sectors)
for i := range db.data {
sector := data[i*sectorSize:]
if len(sector) >= sectorSize {
db.data[i] = (*[sectorSize]byte)(sector)
} else {
db.data[i] = new([sectorSize]byte)
copy((*db.data[i])[:], sector)
}
}
memoryDBs[name] = db
}
// Delete deletes a shared memory database.
func Delete(name string) {
memoryMtx.Lock()
defer memoryMtx.Unlock()
delete(memoryDBs, name)
}

51
vfs/memdb/example_test.go Normal file
View File

@@ -0,0 +1,51 @@
package memdb_test
import (
"database/sql"
"fmt"
"log"
_ "embed"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/test.db
var testDB []byte
func Example() {
memdb.Create("test.db", testDB)
db, err := sql.Open("sqlite3", "file:/test.db?vfs=memdb")
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`)
if err != nil {
log.Fatal(err)
}
rows, err := db.Query(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id, name string
err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s %s\n", id, name)
}
// Output:
// 0 go
// 1 zig
// 2 whatever
// 3 rust
}

294
vfs/memdb/memdb.go Normal file
View File

@@ -0,0 +1,294 @@
package memdb
import (
"io"
"runtime"
"strings"
"sync"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
type memVFS struct{}
func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
// Allowed file types:
// - databases, which only do page aligned reads/writes;
// - temp journals, used by the sorter, which does the same.
const types = vfs.OPEN_MAIN_DB |
vfs.OPEN_TRANSIENT_DB |
vfs.OPEN_TEMP_DB |
vfs.OPEN_TEMP_JOURNAL
if flags&types == 0 {
return nil, flags, sqlite3.CANTOPEN
}
var db *memDB
shared := strings.HasPrefix(name, "/")
if shared {
memoryMtx.Lock()
defer memoryMtx.Unlock()
db = memoryDBs[name[1:]]
}
if db == nil {
if flags&vfs.OPEN_CREATE == 0 {
return nil, flags, sqlite3.CANTOPEN
}
db = new(memDB)
}
if shared {
memoryDBs[name[1:]] = db // +checklocksignore: lock is held
}
return &memFile{
memDB: db,
readOnly: flags&vfs.OPEN_READONLY != 0,
}, flags | vfs.OPEN_MEMORY, nil
}
func (memVFS) Delete(name string, dirSync bool) error {
return sqlite3.IOERR_DELETE
}
func (memVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return false, nil
}
func (memVFS) FullPathname(name string) (string, error) {
return name, nil
}
// Must be a multiple of 64K (the largest page size).
const sectorSize = 65536
type memDB struct {
// +checklocks:lockMtx
pending *memFile
// +checklocks:lockMtx
reserved *memFile
// +checklocks:dataMtx
data []*[sectorSize]byte
// +checklocks:dataMtx
size int64
// +checklocks:lockMtx
shared int
lockMtx sync.Mutex
dataMtx sync.RWMutex
}
type memFile struct {
*memDB
lock vfs.LockLevel
readOnly bool
}
var (
// Ensure these interfaces are implemented:
_ vfs.FileLockState = &memFile{}
_ vfs.FileSizeHint = &memFile{}
)
func (m *memFile) Close() error {
return m.Unlock(vfs.LOCK_NONE)
}
func (m *memFile) ReadAt(b []byte, off int64) (n int, err error) {
m.dataMtx.RLock()
defer m.dataMtx.RUnlock()
if off >= m.size {
return 0, io.EOF
}
base := off / sectorSize
rest := off % sectorSize
have := int64(sectorSize)
if base == int64(len(m.data))-1 {
have = modRoundUp(m.size, sectorSize)
}
n = copy(b, (*m.data[base])[rest:have])
if n < len(b) {
// Assume reads are page aligned.
return 0, io.ErrNoProgress
}
return n, nil
}
func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
base := off / sectorSize
rest := off % sectorSize
for base >= int64(len(m.data)) {
m.data = append(m.data, new([sectorSize]byte))
}
n = copy((*m.data[base])[rest:], b)
if n < len(b) {
// Assume writes are page aligned.
return 0, io.ErrShortWrite
}
if size := off + int64(len(b)); size > m.size {
m.size = size
}
return n, nil
}
func (m *memFile) Truncate(size int64) error {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
return m.truncate(size)
}
// +checklocks:m.dataMtx
func (m *memFile) truncate(size int64) error {
if size < m.size {
base := size / sectorSize
rest := size % sectorSize
if rest != 0 {
clear((*m.data[base])[rest:])
}
}
sectors := divRoundUp(size, sectorSize)
for sectors > int64(len(m.data)) {
m.data = append(m.data, new([sectorSize]byte))
}
clear(m.data[sectors:])
m.data = m.data[:sectors]
m.size = size
return nil
}
func (*memFile) Sync(flag vfs.SyncFlag) error {
return nil
}
func (m *memFile) Size() (int64, error) {
m.dataMtx.RLock()
defer m.dataMtx.RUnlock()
return m.size, nil
}
func (m *memFile) Lock(lock vfs.LockLevel) error {
if m.lock >= lock {
return nil
}
if m.readOnly && lock >= vfs.LOCK_RESERVED {
return sqlite3.IOERR_LOCK
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
switch lock {
case vfs.LOCK_SHARED:
if m.pending != nil {
return sqlite3.BUSY
}
m.shared++
case vfs.LOCK_RESERVED:
if m.reserved != nil {
return sqlite3.BUSY
}
m.reserved = m
case vfs.LOCK_EXCLUSIVE:
if m.lock < vfs.LOCK_PENDING {
if m.pending != nil {
return sqlite3.BUSY
}
m.lock = vfs.LOCK_PENDING
m.pending = m
}
for start := time.Now(); m.shared > 1; {
if time.Since(start) > time.Millisecond {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
runtime.Gosched()
m.lockMtx.Lock()
}
}
m.lock = lock
return nil
}
func (m *memFile) Unlock(lock vfs.LockLevel) error {
if m.lock <= lock {
return nil
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
if m.pending == m {
m.pending = nil
}
if m.reserved == m {
m.reserved = nil
}
if lock < vfs.LOCK_SHARED {
m.shared--
}
m.lock = lock
return nil
}
func (m *memFile) CheckReservedLock() (bool, error) {
if m.lock >= vfs.LOCK_RESERVED {
return true, nil
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
return m.reserved != nil, nil
}
func (*memFile) SectorSize() int {
return sectorSize
}
func (*memFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_ATOMIC |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_SAFE_APPEND |
vfs.IOCAP_POWERSAFE_OVERWRITE
}
func (m *memFile) SizeHint(size int64) error {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
if size > m.size {
return m.truncate(size)
}
return nil
}
func (m *memFile) LockState() vfs.LockLevel {
return m.lock
}
func divRoundUp(a, b int64) int64 {
return (a + b - 1) / b
}
func modRoundUp(a, b int64) int64 {
return b - (b-a%b)%b
}
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = zero
}
}

BIN
vfs/memdb/testdata/test.db vendored Normal file

Binary file not shown.

56
vfs/os_bsd.go Normal file
View File

@@ -0,0 +1,56 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
if start == 0 && len == 0 {
err := unix.Flock(int(file.Fd()), unix.LOCK_UN)
if err != nil {
return _IOERR_UNLOCK
}
}
return _OK
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_SH|unix.LOCK_NB, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_EX|unix.LOCK_NB, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

106
vfs/os_darwin.go Normal file
View File

@@ -0,0 +1,106 @@
//go:build !sqlite3_bsd
package vfs
import (
"io"
"os"
"time"
"golang.org/x/sys/unix"
)
const (
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
_F_OFD_SETLK = 90
_F_OFD_SETLKW = 91
_F_OFD_GETLK = 92
_F_OFD_SETLKWTIMEOUT = 93
)
type flocktimeout_t struct {
fl unix.Flock_t
timeout unix.Timespec
}
func osSync(file *os.File, fullsync, dataonly bool) error {
if fullsync {
return file.Sync()
}
return unix.Fsync(int(file.Fd()))
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
// https://stackoverflow.com/a/11497568/867786
store := unix.Fstore_t{
Flags: unix.F_ALLOCATECONTIG,
Posmode: unix.F_PEOFPOSMODE,
Offset: 0,
Length: size,
}
// Try to get a continuous chunk of disk space.
err = unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
if err != nil {
// OK, perhaps we are too fragmented, allocate non-continuous.
store.Flags = unix.F_ALLOCATEALL
unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
}
return file.Truncate(size)
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := flocktimeout_t{fl: unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}}
var err error
if timeout == 0 {
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &lock.fl)
} else {
lock.timeout = unix.NsecToTimespec(int64(timeout / time.Nanosecond))
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLKWTIMEOUT, &lock.fl)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), _F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

25
vfs/os_linux.go Normal file
View File

@@ -0,0 +1,25 @@
package vfs
import (
"os"
"golang.org/x/sys/unix"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
if dataonly {
_, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0)
if err != 0 {
return err
}
return nil
}
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
return unix.Fallocate(int(file.Fd()), 0, 0, size)
}

63
vfs/os_ofd.go Normal file
View File

@@ -0,0 +1,63 @@
//go:build linux || illumos
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

23
vfs/os_other.go Normal file
View File

@@ -0,0 +1,23 @@
//go:build !linux && (!darwin || sqlite3_bsd)
package vfs
import (
"io"
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
return file.Truncate(size)
}

100
vfs/os_unix.go Normal file
View File

@@ -0,0 +1,100 @@
//go:build unix
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
case ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case ACCESS_READ:
access = unix.R_OK
}
return unix.Access(path, access)
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
if sys, ok := fi.Sys().(*syscall.Stat_t); ok {
file.Chown(int(sys.Uid), int(sys.Gid))
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

249
vfs/os_windows.go Normal file
View File

@@ -0,0 +1,249 @@
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/windows"
)
// osOpenFile is a simplified copy of [os.openFileNolog]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/os/file_windows.go
//
// See: https://go.dev/issue/32088
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT}
}
r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
if rc == _OK {
// Acquire the SHARED lock.
rc = osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
// Release the PENDING lock.
osUnlock(file, _PENDING_BYTE, 1)
}
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
if rc != _OK {
// Reacquire the SHARED lock.
osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
return rc
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
// Release all locks.
if state >= LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= LOCK_SHARED {
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err == windows.ERROR_NOT_LOCKED {
return _OK
}
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
0, len, 0, &windows.Overlapped{Offset: start})
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len uint32) (bool, _ErrorCode) {
rc := osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, 0, _IOERR_CHECKRESERVEDLOCK)
if rc == _BUSY {
return true, _OK
}
if rc == _OK {
osUnlock(file, start, len)
}
return false, rc
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(windows.Errno); ok {
// https://devblogs.microsoft.com/oldnewthing/20140905-00/?p=63
switch errno {
case
windows.ERROR_LOCK_VIOLATION,
windows.ERROR_IO_PENDING,
windows.ERROR_OPERATION_ABORTED:
return _BUSY
}
}
return def
}
// syscallOpen is a simplified copy of [syscall.Open]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
var access uint32
switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) {
case syscall.O_RDONLY:
access = syscall.GENERIC_READ
case syscall.O_WRONLY:
access = syscall.GENERIC_WRITE
case syscall.O_RDWR:
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
}
if mode&syscall.O_CREAT != 0 {
access |= syscall.GENERIC_WRITE
}
if mode&syscall.O_APPEND != 0 {
access &^= syscall.GENERIC_WRITE
access |= syscall.FILE_APPEND_DATA
}
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var createmode uint32
switch {
case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL):
createmode = syscall.CREATE_NEW
case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC):
createmode = syscall.CREATE_ALWAYS
case mode&syscall.O_CREAT == syscall.O_CREAT:
createmode = syscall.OPEN_ALWAYS
case mode&syscall.O_TRUNC == syscall.O_TRUNC:
createmode = syscall.TRUNCATE_EXISTING
default:
createmode = syscall.OPEN_EXISTING
}
var attrs uint32 = syscall.FILE_ATTRIBUTE_NORMAL
if perm&syscall.S_IWRITE == 0 {
attrs = syscall.FILE_ATTRIBUTE_READONLY
}
if createmode == syscall.OPEN_EXISTING && access == syscall.GENERIC_READ {
// Necessary for opening directory handles.
attrs |= syscall.FILE_FLAG_BACKUP_SEMANTICS
}
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, attrs, 0)
}

5
vfs/readervfs/README.md Normal file
View File

@@ -0,0 +1,5 @@
# Go `"reader"` SQLite VFS
This package implements a `"reader"` SQLite VFS
that allows accessing any [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
as an immutable SQLite database.

84
vfs/readervfs/api.go Normal file
View File

@@ -0,0 +1,84 @@
// Package readervfs implements an SQLite VFS for immutable databases.
//
// The "reader" [vfs.VFS] permits accessing any [io.ReaderAt]
// as an immutable SQLite database.
//
// Importing package readervfs registers the VFS.
//
// import _ "github.com/ncruces/go-sqlite3/vfs/readervfs"
package readervfs
import (
"io"
"io/fs"
"sync"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("reader", readerVFS{})
}
var (
readerMtx sync.RWMutex
// +checklocks:readerMtx
readerDBs = map[string]SizeReaderAt{}
)
// Create creates an immutable database from reader.
// The caller should ensure that data from reader does not mutate,
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
func Create(name string, reader SizeReaderAt) {
readerMtx.Lock()
defer readerMtx.Unlock()
readerDBs[name] = reader
}
// Delete deletes a shared memory database.
func Delete(name string) {
readerMtx.Lock()
defer readerMtx.Unlock()
delete(readerDBs, name)
}
// A SizeReaderAt is a ReaderAt with a Size method.
// Use [NewSizeReaderAt] to adapt different Size interfaces.
type SizeReaderAt interface {
Size() (int64, error)
io.ReaderAt
}
// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt
// that implements one of:
// - Size() (int64, error)
// - Size() int64
// - Len() int
// - Stat() (fs.FileInfo, error)
// - Seek(offset int64, whence int) (int64, error)
func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt {
return sizer{r}
}
type sizer struct{ io.ReaderAt }
func (s sizer) Size() (int64, error) {
switch s := s.ReaderAt.(type) {
case interface{ Size() (int64, error) }:
return s.Size()
case interface{ Size() int64 }:
return s.Size(), nil
case interface{ Len() int }:
return int64(s.Len()), nil
case interface{ Stat() (fs.FileInfo, error) }:
fi, err := s.Stat()
if err != nil {
return 0, err
}
return fi.Size(), nil
case io.Seeker:
return s.Seek(0, io.SeekEnd)
}
return 0, sqlite3.IOERR_SEEK
}

Some files were not shown because too many files have changed in this diff Show More