Compare commits

...

94 Commits

Author SHA1 Message Date
Nuno Cruces
a7c00eb150 SQLite 3.44.0. 2023-11-03 03:43:14 -07:00
Nuno Cruces
0bcdb712ba SQL json_time function. 2023-11-03 03:40:46 -07:00
Nuno Cruces
2157d0f325 Interrupts: avoid goroutine. 2023-10-25 14:12:21 +01:00
Nuno Cruces
6353160619 Improve benchmark repeatability. 2023-10-25 13:17:37 +01:00
Nuno Cruces
501d157279 Update BSD test. 2023-10-24 23:27:10 +01:00
Nuno Cruces
4db18a7b9a JSON encoding fix. 2023-10-19 16:46:58 +01:00
Nuno Cruces
a9dddaa86c Optimize VFS search. 2023-10-19 16:43:54 +01:00
Nuno Cruces
b25936dbec Unix formats return UTC. 2023-10-19 12:07:03 +01:00
dependabot[bot]
bf23041e46 Bump github.com/ncruces/julianday from 0.1.5 to 1.0.0 (#33)
Bumps [github.com/ncruces/julianday](https://github.com/ncruces/julianday) from 0.1.5 to 1.0.0.
- [Commits](https://github.com/ncruces/julianday/compare/v0.1.5...v1.0.0)

---
updated-dependencies:
- dependency-name: github.com/ncruces/julianday
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-19 00:34:35 +01:00
Nuno Cruces
d60fceac92 JSON support. 2023-10-18 23:14:46 +01:00
Nuno Cruces
61da30f44a Allow configuring wazero. 2023-10-18 15:06:32 +01:00
Nuno Cruces
d4ff605983 Time scanner. 2023-10-18 15:06:12 +01:00
Nuno Cruces
8d0c654178 Cross compilation. 2023-10-17 15:30:08 +01:00
Nuno Cruces
728e59951b Test BSD. 2023-10-16 12:51:49 +01:00
Nuno Cruces
f7b16bad5c Patch flaky tests. 2023-10-16 12:26:25 +01:00
Nuno Cruces
db3e6da31a BSD locks. 2023-10-16 02:11:20 +01:00
Nuno Cruces
3f443b2ecc API change. 2023-10-13 18:53:37 +01:00
Nuno Cruces
eec45ea684 Towards JSON. 2023-10-13 17:06:05 +01:00
Nuno Cruces
f6d77f3cf4 GORM v1.25.5. 2023-10-13 00:42:06 +01:00
Nuno Cruces
d5d7cd1f2d SQLite 3.43.2. 2023-10-12 10:52:48 +01:00
dependabot[bot]
a33a187d48 Bump golang.org/x/sys from 0.12.0 to 0.13.0 (#30)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.12.0 to 0.13.0.
- [Commits](https://github.com/golang/sys/compare/v0.12.0...v0.13.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-05 23:53:11 +01:00
dependabot[bot]
70c6ee15c6 Bump golang.org/x/sync from 0.3.0 to 0.4.0 (#29)
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.3.0 to 0.4.0.
- [Commits](https://github.com/golang/sync/compare/v0.3.0...v0.4.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-05 23:43:11 +01:00
Nuno Cruces
994d9b1812 Updated dependencies. 2023-10-02 10:09:26 +01:00
Nuno Cruces
b19bd28ed3 Simplify lock timeouts. 2023-10-02 10:06:09 +01:00
Nuno Cruces
e66bd51845 More VFS API. 2023-09-21 02:43:45 +01:00
Nuno Cruces
f5614bc2ed Tweaks. 2023-09-20 15:07:07 +01:00
Nuno Cruces
d9fcf60b7d Driver API. 2023-09-20 02:41:09 +01:00
Nuno Cruces
ac6dd1aa5f Updated dependencies. 2023-09-18 15:22:11 +01:00
Nuno Cruces
b1495bd6cb Build tags, docs. 2023-09-18 15:11:05 +01:00
Nuno Cruces
2d91760295 Portability. 2023-09-18 12:44:18 +01:00
Nuno Cruces
38d4254bc4 Update README.md 2023-09-15 15:37:57 +01:00
Nuno Cruces
c0aa734786 binaryen-version_116. 2023-09-15 15:10:08 +01:00
Nuno Cruces
fa845dbd3d Run test in all platforms. 2023-09-12 15:30:43 +01:00
Nuno Cruces
fed315ab79 Update go.yml 2023-09-12 15:28:11 +01:00
Nuno Cruces
726d7316f7 Update README.md 2023-09-12 00:00:32 +01:00
Nuno Cruces
ddb387b021 Updated dependencies. 2023-09-11 23:54:22 +01:00
Nuno Cruces
d0f19507f5 SQLite 3.43.1. 2023-09-11 23:48:38 +01:00
Nuno Cruces
9d997552ad Pearson correlation. 2023-09-02 00:48:55 +01:00
Nuno Cruces
9d75c39dcc Update README.md 2023-09-01 16:01:42 +01:00
Nuno Cruces
746a84965e Covariance. 2023-09-01 02:38:57 +01:00
Nuno Cruces
312d3b58f2 Statistics functions. 2023-09-01 01:23:25 +01:00
Nuno Cruces
b71cd295c2 Updated dependencies. 2023-08-25 09:56:09 +01:00
Nuno Cruces
5b3b61a304 SQLite 3.43.0. 2023-08-24 18:56:23 +01:00
Nuno Cruces
d661d15723 wazero v1.5.0. 2023-08-24 18:56:10 +01:00
Nuno Cruces
1e38165ad0 Timer resolution. 2023-08-20 03:12:55 +01:00
Nuno Cruces
58a32d7c9d Update GORM. 2023-08-20 00:56:08 +01:00
Nuno Cruces
6765e883c1 Register collation. 2023-08-10 13:39:52 +01:00
Nuno Cruces
18fc608433 Embed database as string. 2023-08-10 13:23:54 +01:00
Nuno Cruces
77f37893b9 Driver connector. 2023-08-10 13:18:13 +01:00
Nuno Cruces
f1e36e2581 Updated dependencies. 2023-08-09 16:30:32 +01:00
Nuno Cruces
772b9153c7 Use clear builtin. 2023-08-09 16:16:45 +01:00
Nuno Cruces
4b280a3a7e Updated dependencies. 2023-08-09 15:22:48 +01:00
Nuno Cruces
19b6098bf6 Update go.yml (#28) 2023-08-05 01:12:16 +01:00
dependabot[bot]
2aa685320f Bump golang.org/x/text from 0.11.0 to 0.12.0 (#26)
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.11.0 to 0.12.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.11.0...v0.12.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-05 00:36:56 +01:00
dependabot[bot]
9941be05c2 Bump golang.org/x/sys from 0.10.0 to 0.11.0 (#27)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.10.0 to 0.11.0.
- [Commits](https://github.com/golang/sys/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-05 00:35:11 +01:00
Nuno Cruces
a0a9ab7737 Avoid unnecessary alloc. 2023-08-04 14:12:36 +01:00
Nuno Cruces
a77727a1ce Port script. 2023-07-31 15:27:10 +01:00
Nuno Cruces
47fe032078 Updated dependencies. 2023-07-26 12:42:18 +01:00
Nuno Cruces
bdfe279444 Soundex. 2023-07-26 02:02:39 +01:00
dependabot[bot]
a86937a54e Bump github.com/tetratelabs/wazero from 1.3.0 to 1.3.1
Bumps [github.com/tetratelabs/wazero](https://github.com/tetratelabs/wazero) from 1.3.0 to 1.3.1.
- [Release notes](https://github.com/tetratelabs/wazero/releases)
- [Commits](https://github.com/tetratelabs/wazero/compare/v1.3.0...v1.3.1)

---
updated-dependencies:
- dependency-name: github.com/tetratelabs/wazero
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-25 08:02:20 +01:00
Nuno Cruces
6ef422fbde Unicode tests. 2023-07-13 12:19:32 +01:00
Nuno Cruces
ff0cb6fb88 Unicode tests, fixes. 2023-07-12 13:39:07 +01:00
Nuno Cruces
72db90efdf Unicode. 2023-07-11 16:34:15 +01:00
Nuno Cruces
5a3fdef3c5 wazero v1.3.0. 2023-07-11 12:30:39 +01:00
dependabot[bot]
ff34b0cae1 Bump golang.org/x/text from 0.10.0 to 0.11.0
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-04 23:55:17 +01:00
Nuno Cruces
f064492bb1 Updated dependencies. 2023-07-04 19:55:11 +01:00
Nuno Cruces
1427d30541 Updated dependencies. 2023-07-04 19:48:55 +01:00
Nuno Cruces
d3730341f0 Unknown collations. 2023-07-04 11:16:29 +01:00
Nuno Cruces
78ac2386f6 Refactor. 2023-07-04 02:29:38 +01:00
Nuno Cruces
632ea933b3 Function aux data. 2023-07-04 02:18:03 +01:00
Nuno Cruces
0f7fa6ebc9 Tests. 2023-07-03 18:28:46 +01:00
Nuno Cruces
6f7f776488 Refactor. 2023-07-03 17:42:53 +01:00
Nuno Cruces
f6d7c5e9c5 Refactor. 2023-07-03 17:08:16 +01:00
Nuno Cruces
1cc7ecfe8d Custom aggregate functions. 2023-07-03 15:45:16 +01:00
Nuno Cruces
3844e81404 Custom aggregate functions. 2023-07-01 15:19:45 +01:00
Nuno Cruces
fec1f8d32a Custom scalar functions. 2023-07-01 00:16:42 +01:00
Nuno Cruces
31572e6095 Fix nil/zero handles. 2023-06-30 17:09:01 +01:00
Nuno Cruces
4aee38b957 Error handling. 2023-06-30 12:25:07 +01:00
Nuno Cruces
232a7705b5 Wrap context. 2023-06-30 11:48:54 +01:00
Nuno Cruces
a6c2fccd74 Wrap value. 2023-06-30 10:45:16 +01:00
Nuno Cruces
6a982559cd Custom collating sequences. 2023-06-30 02:49:21 +01:00
Nuno Cruces
c7904d30de Refactor file handles. 2023-06-30 01:52:18 +01:00
Nuno Cruces
ce4386604d GORM v1.25.1. 2023-06-29 20:06:56 +01:00
Nuno Cruces
26b62c520d Towards SQL functions. 2023-06-29 14:21:59 +01:00
Nuno Cruces
738714bf32 Fix WAL. 2023-06-26 13:31:42 +01:00
Nuno Cruces
41b020bafc go-sqlite3 v0.8.0. 2023-06-16 17:21:50 +01:00
Nuno Cruces
d0e720272b Optimization flags. 2023-06-15 15:57:39 +01:00
Nuno Cruces
76171da12b go-sqlite3 v0.7.3. 2023-06-15 03:56:02 +01:00
Nuno Cruces
dcc845d684 wazero v1.2.1. 2023-06-15 03:43:25 +01:00
dependabot[bot]
f1b42c26d5 Bump golang.org/x/sync from 0.2.0 to 0.3.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.2.0 to 0.3.0.
- [Commits](https://github.com/golang/sync/compare/v0.2.0...v0.3.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-15 00:13:43 +01:00
dependabot[bot]
1e94407ae7 Bump golang.org/x/sys from 0.8.0 to 0.9.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.8.0 to 0.9.0.
- [Commits](https://github.com/golang/sys/compare/v0.8.0...v0.9.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-13 00:31:02 +01:00
Nuno Cruces
eb8d9b95fd Consistent lock timeouts. 2023-06-12 13:04:37 +01:00
Nuno Cruces
04037a75ed GORM driver sync. 2023-06-12 10:56:03 +01:00
Nuno Cruces
2472ceb0a0 Fix GORM module name. 2023-06-07 12:40:18 +01:00
115 changed files with 4000 additions and 1027 deletions

29
.github/workflows/bsd.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: BSD
on:
workflow_dispatch:
jobs:
test:
runs-on: macos-12
steps:
- uses: actions/checkout@v3
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: GOOS=freebsd go test -c ./...
- name: Test
uses: cross-platform-actions/action@v0.20.0
with:
operating_system: freebsd
version: '13.2'
sync_files: runner-to-vm
run: find . -name '*.test' -maxdepth 1 -exec {} -test.v \;

22
.github/workflows/cross.sh vendored Executable file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env bash
echo android ; GOOS=android GOARCH=amd64 go build .
echo darwin ; GOOS=darwin GOARCH=amd64 go build .
echo dragonfly ; GOOS=dragonfly GOARCH=amd64 go build .
echo freebsd ; GOOS=freebsd GOARCH=amd64 go build .
echo illumos ; GOOS=illumos GOARCH=amd64 go build .
echo ios ; GOOS=ios GOARCH=amd64 go build .
echo linux ; GOOS=linux GOARCH=amd64 go build .
echo netbsd ; GOOS=netbsd GOARCH=amd64 go build .
echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
echo windows ; GOOS=windows GOARCH=amd64 go build .
# echo aix ; GOOS=aix GOARCH=ppc64 go build .
echo js ; GOOS=js GOARCH=wasm go build .
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
echo darwin-nosys ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_nosys .
echo linux-nosys ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_nosys .
echo windows-nosys ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_nosys .
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .

21
.github/workflows/cross.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: Cross compile
on:
workflow_dispatch:
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: .github/workflows/cross.sh

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
lfs: 'true'
@@ -39,7 +39,6 @@ jobs:
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -47,16 +46,19 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test no locks
run: go test -v -tags sqlite3_nosys ./tests -run TestDB_nolock
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
run: go test -v -tags sqlite3_flock ./...
if: matrix.os == 'macos-latest'
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
chart: true
amend: true
reuse-go: true
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'
continue-on-error: true
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'

View File

@@ -7,18 +7,26 @@
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
implements an in-memory VFS.
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
implements a VFS for immutable databases.
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
registers Unicode aware functions.
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Caveats
@@ -31,57 +39,64 @@ This has benefits, but also comes with some drawbacks.
Because WASM does not support shared memory,
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
To work around this limitation, SQLite is compiled with
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
making `EXCLUSIVE` the default locking mode.
For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
to always use `EXCLUSIVE` locking mode for WAL databases.
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
to use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with WAL mode databases you should disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### POSIX Advisory Locks
#### File Locking
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
POSIX advisory locks, which SQLite uses on Unix, are
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
OFD locks are fully compatible with POSIX advisory locks.
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
On BSD Unixes, BSD locks are fully compatible with POSIX advisory locks.
On Windows, this module uses `LockFile`, `LockFileEx`, and `UnlockFile`,
like SQLite.
On all other platforms, file locking is not supported, and you must use
[`nolock=1`](https://www.sqlite.org/uri.html#urinolock)
to open database files.
To use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with `nolock=1` you must disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Testing
The pure Go VFS is tested by running an unmodified build of SQLite's
The pure Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
on Linux, macOS, Windows and FreeBSD.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Roadmap
- [ ] advanced SQLite features
- [x] custom functions
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [x] JSON support
- [ ] session extension
- [ ] custom VFSes
- [x] custom VFS API
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom SQL functions
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

View File

@@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r, dst)
return nil, c.sqlite.error(r, dst)
}
return &Backup{

View File

@@ -118,8 +118,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
return 0, nil
}
want := int64(1024 * 1024)
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
@@ -175,8 +175,11 @@ func (b *Blob) Write(p []byte) (n int, err error) {
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want := int64(1024 * 1024)
avail := b.bytes - b.offset
want := int64(65536)
if l, ok := r.(*io.LimitedReader); ok && want > l.N {
want = l.N
}
if want > avail {
want = avail
}

109
conn.go
View File

@@ -2,16 +2,14 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
@@ -19,10 +17,9 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
*module
*sqlite
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
@@ -39,7 +36,7 @@ func Open(filename string) (*Conn, error) {
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
@@ -49,21 +46,24 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(filename, flags)
}
type connKey struct{}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
mod.close()
sqlite.close()
} else {
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
c := &Conn{module: mod}
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
@@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r, handle, pragmas.String()); err != nil {
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
@@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
panic(err)
}
}
@@ -132,7 +132,6 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
@@ -143,7 +142,7 @@ func (c *Conn) Close() error {
c.handle = 0
runtime.SetFinalizer(c, nil)
return c.module.close()
return c.close()
}
// Exec is a convenience function that allows an application to run
@@ -240,65 +239,45 @@ func (c *Conn) Changes() int64 {
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
// Is it the same context?
if ctx == c.interrupt {
return ctx
}
// Reset the pending statement.
if c.pending != nil {
// An uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
} else {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
// Remove the handler if the context can't be canceled.
if ctx == nil || ctx.Done() == nil {
c.call(c.api.progressHandler, uint64(c.handle), 0)
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
const isInterruptedOffset = 280
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
c.call(c.api.progressHandler, uint64(c.handle), 100)
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt != nil && c.interrupt.Err() != nil {
return 1
}
}
return 0
}
func (c *Conn) checkInterrupt() {
if c.interrupt != nil && c.interrupt.Err() != nil {
c.call(c.api.interrupt, uint64(c.handle))
}
const isInterruptedOffset = 280
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.
@@ -319,7 +298,7 @@ func (c *Conn) Pragma(str string) ([]string, error) {
}
func (c *Conn) error(rc uint64, sql ...string) error {
return c.module.error(rc, c.handle, sql...)
return c.sqlite.error(rc, c.handle, sql...)
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
@@ -331,15 +310,5 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
SetInterrupt(ctx context.Context) (old context.Context)
Savepoint() Savepoint
Backup(srcDB, dstURI string) error
Restore(dstDB, srcURI string) error
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
Raw() *Conn
}

View File

@@ -167,6 +167,18 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
)
// Datatype is a fundamental datatype of SQLite.
//
// https://www.sqlite.org/c3ref/c_blob.html
@@ -182,18 +194,18 @@ const (
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
const name = "INTEGERFLOATEXTBLOBNULL"
switch t {
case INTEGER:
return name[0:7]
case FLOAT:
return name[7:12]
case TEXT:
return name[12:16]
return name[11:15]
case BLOB:
return name[16:20]
return name[15:19]
case NULL:
return name[20:24]
return name[19:23]
}
return strconv.FormatUint(uint64(t), 10)
}

200
context.go Normal file
View File

@@ -0,0 +1,200 @@
package sqlite3
import (
"encoding/json"
"errors"
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Context is the context in which an SQL function executes.
// An SQLite [Context] is in no way related to a Go [context.Context].
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
*sqlite
handle uint32
}
// SetAuxData saves metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(c.ctx, data)
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) GetAuxData(n int) any {
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
return util.GetHandle(c.ctx, ptr)
}
// ResultBool sets the result of the function to a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBool(value bool) {
var i int64
if value {
i = 1
}
c.ResultInt64(i)
}
// ResultInt sets the result of the function to an int.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt(value int) {
c.ResultInt64(int64(value))
}
// ResultInt64 sets the result of the function to an int64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt64(value int64) {
c.call(c.api.resultInteger,
uint64(c.handle), uint64(value))
}
// ResultFloat sets the result of the function to a float64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultFloat(value float64) {
c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value))
}
// ResultText sets the result of the function to a string.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultText(value string) {
ptr := c.newString(value)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBlob(value []byte) {
ptr := c.newBytes(value)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor))
}
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultZeroBlob(n int64) {
c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n))
}
// ResultNull sets the result of the function to NULL.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultNull() {
c.call(c.api.resultNull,
uint64(c.handle))
}
// ResultTime sets the result of the function to a [time.Time].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultTime(value time.Time, format TimeFormat) {
if format == TimeFormatDefault {
c.resultRFC3339Nano(value)
return
}
switch v := format.Encode(value).(type) {
case string:
c.ResultText(v)
case int64:
c.ResultInt64(v)
case float64:
c.ResultFloat(v)
default:
panic(util.AssertErr())
}
}
func (c Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := c.new(maxlen)
buf := util.View(c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.api.destructor), _UTF8)
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
c.ResultError(err)
}
ptr := c.newBytes(data)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(data)),
uint64(c.api.destructor))
}
// ResultValue sets the result of the function a copy of [Value].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultValue(value Value) {
if value.sqlite != c.sqlite {
c.ResultError(MISUSE)
}
c.call(c.api.resultValue,
uint64(c.handle), uint64(value.handle))
}
// ResultError sets the result of the function an error.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
c.call(c.api.resultErrorMem, uint64(c.handle))
return
}
if errors.Is(err, TOOBIG) {
c.call(c.api.resultErrorBig, uint64(c.handle))
return
}
str := err.Error()
ptr := c.newString(str)
c.call(c.api.resultError,
uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
var code uint64
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
code = uint64(xcode)
case errors.As(err, &ecode):
code = uint64(ecode)
}
if code != 0 {
c.call(c.api.resultErrorCode,
uint64(c.handle), code)
}
}

View File

@@ -14,10 +14,9 @@
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
//
// If no PRAGMAs are specified, a busy timeout of 1 minute
// and normal locking mode are used.
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
//
// Order matters:
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
@@ -31,6 +30,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
@@ -41,65 +42,117 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
var driverName = "sqlite3"
func init() {
sql.Register("sqlite3", sqlite{})
if driverName != "" {
sql.Register(driverName, sqlite{})
}
}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to [database/sql].
func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
var c conn
c.Conn, err = sqlite3.Open(name)
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := newConnector(name, nil)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
c.txBegin = "BEGIN"
var pragmas []string
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
return newConnector(name, nil)
}
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
query, err := url.ParseQuery(after)
if err != nil {
return nil, err
}
pragmas = query["_pragma"]
c.txlock = query.Get("_txlock")
c.pragmas = len(query["_pragma"]) > 0
}
}
if len(pragmas) == 0 {
err := c.Conn.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
return &c, nil
}
type connector struct {
init func(*sqlite3.Conn) error
name string
txlock string
pragmas bool
}
func (n *connector) Driver() driver.Driver {
return sqlite{}
}
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
var c conn
c.Conn, err = sqlite3.Open(n.name)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
c.Close()
return nil, err
}
c.reusable = true
} else {
s, _, err := c.Conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
}()
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
switch n.txlock {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + n.txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
}
if !n.pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
if n.init != nil {
err = n.init(c.Conn)
if err != nil {
return nil, err
}
}
if n.pragmas || n.init != nil {
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
if err != nil {
return nil, err
}
if s.Step() && s.ColumnBool(0) {
c.readOnly = '1'
} else {
c.readOnly = '0'
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}
@@ -111,20 +164,19 @@ type conn struct {
txBegin string
txCommit string
txRollback string
reusable bool
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
_ driver.ConnPrepareContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c *conn) IsValid() bool {
return c.reusable
func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
func (c *conn) Begin() (driver.Tx, error) {
@@ -140,10 +192,10 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
c.txRollback = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
c.txCommit = c.txRollback
}
switch opts.Isolation {
@@ -167,14 +219,20 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
if err != nil && !c.Conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c *conn) Rollback() error {
return c.Conn.Exec(c.txRollback)
err := c.Conn.Exec(c.txRollback)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(c.txRollback)
}
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
@@ -222,6 +280,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return newResult(c.Conn), nil
}
func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
return nil
}
type stmt struct {
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
@@ -259,13 +321,14 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
}
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
// Use QueryContext to setup bindings.
// No need to close rows: that simply resets the statement, exec does the same.
_, err := s.QueryContext(ctx, args)
err := s.setupBindings(args)
if err != nil {
return nil, err
}
old := s.Conn.SetInterrupt(ctx)
defer s.Conn.SetInterrupt(old)
err = s.Stmt.Exec()
if err != nil {
return nil, err
@@ -275,10 +338,18 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.Stmt.ClearBindings()
err := s.setupBindings(args)
if err != nil {
return nil, err
}
return &rows{ctx, s.Stmt, s.Conn}, nil
}
func (s *stmt) setupBindings(args []driver.NamedValue) error {
err := s.Stmt.ClearBindings()
if err != nil {
return err
}
var ids [3]int
for _, arg := range args {
@@ -311,6 +382,8 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case json.Marshaler:
err = s.Stmt.BindJSON(id, a)
case nil:
err = s.Stmt.BindNull(id)
default:
@@ -318,17 +391,16 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
}
}
if err != nil {
return nil, err
return err
}
}
return &rows{ctx, s.Stmt, s.Conn}, nil
return nil
}
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
sqlite3.ZeroBlob, time.Time, json.Marshaler, nil:
return nil
default:
return driver.ErrSkip
@@ -407,11 +479,7 @@ func (r *rows) Next(dest []driver.Value) error {
case sqlite3.TEXT:
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
} else {
dest[i] = nil
}
dest[i] = nil
default:
panic(util.AssertErr())
}

View File

@@ -47,7 +47,7 @@ func ExampleDriverConn() {
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
conn := driverConn.(sqlite3.DriverConn).Raw()
savept := conn.Savepoint()
defer savept.Release(&err)

View File

@@ -1,6 +1,6 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
This folder includes an embeddable WASM build of SQLite 3.44.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
@@ -9,6 +9,7 @@ The following optional features are compiled in:
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [soundex](https://www.sqlite.org/lang_corefunc.html#soundex)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)

View File

@@ -4,24 +4,27 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals \
-msimd128 -mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--initial-memory=327680 \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
$(awk '{print "-Wl,--export="$0}' exports.txt)
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
--enable-multivalue --enable-mutable-globals \
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
sqlite3.tmp -o sqlite3.wasm \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

View File

@@ -13,6 +13,8 @@ sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_interrupt
sqlite3_progress_handler_go
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
@@ -33,10 +35,10 @@ sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_reopen
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
@@ -46,4 +48,30 @@ sqlite3_uri_parameter
sqlite3_uri_key
sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit
sqlite3_get_autocommit
sqlite3_anycollseq_init
sqlite3_create_collation_go
sqlite3_create_function_go
sqlite3_create_aggregate_function_go
sqlite3_create_window_function_go
sqlite3_aggregate_context
sqlite3_user_data
sqlite3_set_auxdata_go
sqlite3_get_auxdata
sqlite3_value_type
sqlite3_value_int64
sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem
sqlite3_result_error_toobig

Binary file not shown.

View File

@@ -68,6 +68,19 @@ func (e *Error) Is(err error) bool {
return false
}
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
func (e *Error) As(err any) bool {
switch c := err.(type) {
case *ErrorCode:
*c = e.Code()
return true
case *ExtendedErrorCode:
*c = e.ExtendedCode()
return true
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
@@ -104,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool {
return ok && c == ErrorCode(e)
}
// As converts this error to an [ErrorCode].
func (e ExtendedErrorCode) As(err any) bool {
c, ok := err.(*ErrorCode)
if ok {
*c = ErrorCode(e)
}
return ok
}
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY

View File

@@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) {
func TestError(t *testing.T) {
t.Parallel()
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
var ecode ErrorCode
var xcode xErrorCode
err := &Error{code: 0x8080}
if !errors.As(err, &err) {
t.Fatal("want true")
}
if !errors.Is(&err, ErrorCode(0x80)) {
if ecode := err.Code(); ecode != 0x80 {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if !errors.Is(err, ErrorCode(0x80)) {
t.Errorf("want true")
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
if xcode := err.ExtendedCode(); xcode != 0x8080 {
t.Errorf("got %#x, want 0x8080", uint16(xcode))
}
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
t.Errorf("got %#x, want 0x8080", uint16(xcode))
}
if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
t.Errorf("want true")
}

109
ext/stats/stats.go Normal file
View File

@@ -0,0 +1,109 @@
// Package stats provides aggregate functions for statistics.
//
// Functions:
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - var_pop: population variance
// - var_samp: sample variance
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
//
// See: [ANSI SQL Aggregate Functions]
//
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import "github.com/ncruces/go-sqlite3"
// Register registers statistics functions.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
}
const (
var_pop = iota
var_samp
stddev_pop
stddev_samp
corr
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
}
type variance struct {
kind int
welford
}
func (fn *variance) Value(ctx sqlite3.Context) {
var r float64
switch fn.kind {
case var_pop:
r = fn.var_pop()
case var_samp:
r = fn.var_samp()
case stddev_pop:
r = fn.stddev_pop()
case stddev_samp:
r = fn.stddev_samp()
}
ctx.ResultFloat(r)
}
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
fn.enqueue(a.Float())
}
}
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
fn.dequeue(a.Float())
}
}
func newCovariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
}
type covariance struct {
kind int
welford2
}
func (fn *covariance) Value(ctx sqlite3.Context) {
var r float64
switch fn.kind {
case var_pop:
r = fn.covar_pop()
case var_samp:
r = fn.covar_samp()
case corr:
r = fn.correlation()
}
ctx.ResultFloat(r)
}
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
fn.enqueue(a.Float(), b.Float())
}
}
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
fn.dequeue(a.Float(), b.Float())
}
}

140
ext/stats/stats_test.go Normal file
View File

@@ -0,0 +1,140 @@
package stats
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestRegister_variance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
stddev_samp(x), stddev_pop(x)
FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
}
if got := stmt.ColumnFloat(1); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(2); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := stmt.ColumnFloat(3); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}
func TestRegister_covariance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}

109
ext/stats/welford.go Normal file
View File

@@ -0,0 +1,109 @@
package stats
import "math"
// Welford's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type welford struct {
m1, m2 kahan
n uint64
}
func (w welford) average() float64 {
return w.m1.hi
}
func (w welford) var_pop() float64 {
return w.m2.hi / float64(w.n)
}
func (w welford) var_samp() float64 {
return w.m2.hi / float64(w.n-1) // Bessel's correction
}
func (w welford) stddev_pop() float64 {
return math.Sqrt(w.var_pop())
}
func (w welford) stddev_samp() float64 {
return math.Sqrt(w.var_samp())
}
func (w *welford) enqueue(x float64) {
w.n++
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}
func (w *welford) dequeue(x float64) {
w.n--
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
type welford2 struct {
m1x, m2x kahan
m1y, m2y kahan
cov kahan
n uint64
}
func (w welford2) covar_pop() float64 {
return w.cov.hi / float64(w.n)
}
func (w welford2) covar_samp() float64 {
return w.cov.hi / float64(w.n-1) // Bessel's correction
}
func (w welford2) correlation() float64 {
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
}
func (w *welford2) enqueue(x, y float64) {
w.n++
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.add(d1x / float64(w.n))
w.m1y.add(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.add(d1x * d2x)
w.m2y.add(d1y * d2y)
w.cov.add(d1x * d2y)
}
func (w *welford2) dequeue(x, y float64) {
w.n--
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.sub(d1x / float64(w.n))
w.m1y.sub(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.sub(d1x * d2x)
w.m2y.sub(d1y * d2y)
w.cov.sub(d1x * d2y)
}
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

75
ext/stats/welford_test.go Normal file
View File

@@ -0,0 +1,75 @@
package stats
import (
"math"
"testing"
)
func Test_welford(t *testing.T) {
var s1, s2 welford
s1.enqueue(4)
s1.enqueue(7)
s1.enqueue(13)
s1.enqueue(16)
if got := s1.average(); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := s1.var_samp(); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := s1.var_pop(); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := s1.stddev_samp(); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
s1.dequeue(4)
s2.enqueue(7)
s2.enqueue(13)
s2.enqueue(16)
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
}
func Test_covar(t *testing.T) {
var c1, c2 welford2
c1.enqueue(3, 70)
c1.enqueue(5, 80)
c1.enqueue(2, 60)
c1.enqueue(7, 90)
c1.enqueue(4, 75)
if got := c1.covar_samp(); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := c1.covar_pop(); got != 17 {
t.Errorf("got %v, want 17", got)
}
c1.dequeue(3, 70)
c2.enqueue(5, 80)
c2.enqueue(2, 60)
c2.enqueue(7, 90)
c2.enqueue(4, 75)
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
}
func Test_correlation(t *testing.T) {
var c welford2
c.enqueue(1, 3)
c.enqueue(2, 2)
c.enqueue(3, 1)
if got := c.correlation(); got != -1 {
t.Errorf("got %v, want -1", got)
}
}

181
ext/unicode/unicode.go Normal file
View File

@@ -0,0 +1,181 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions,
// - LIKE and REGEXP operators,
// - collation sequences.
//
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regex/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
package unicode
import (
"bytes"
"regexp"
"strings"
"unicode/utf8"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"golang.org/x/text/cases"
"golang.org/x/text/collate"
"golang.org/x/text/language"
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("like", 2, flags, like)
db.CreateFunction("like", 3, flags, like)
db.CreateFunction("upper", 1, flags, upper)
db.CreateFunction("upper", 2, flags, upper)
db.CreateFunction("lower", 1, flags, lower)
db.CreateFunction("lower", 2, flags, lower)
db.CreateFunction("regexp", 2, flags, regex)
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
}
})
}
// RegisterCollation registers a Unicode collation sequence for a database connection.
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
tag, err := language.Parse(locale)
if err != nil {
return err
}
return db.CreateCollation(name, collate.New(tag).Compare)
}
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return
}
c := cases.Upper(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
}
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return
}
c := cases.Lower(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
re = r
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
}
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
escape := rune(-1)
if len(arg) == 3 {
var size int
b := arg[2].RawBlob()
escape, size = utf8.DecodeRune(b)
if size != len(b) {
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
return
}
}
type likeData struct {
*regexp.Regexp
escape rune
}
re, ok := ctx.GetAuxData(0).(likeData)
if !ok || re.escape != escape {
re = likeData{
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
escape,
}
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
}
func like2regex(pattern string, escape rune) string {
var re strings.Builder
start := 0
literal := false
re.Grow(len(pattern) + 10)
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
for i, r := range pattern {
if start < 0 {
start = i
}
if literal {
literal = false
continue
}
var symbol string
switch r {
case '_':
symbol = `.`
case '%':
symbol = `.*`
case escape:
literal = true
default:
continue
}
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
re.WriteString(symbol)
start = -1
}
if start >= 0 {
re.WriteString(regexp.QuoteMeta(pattern[start:]))
}
re.WriteString(`\z`)
return re.String()
}

215
ext/unicode/unicode_test.go Normal file
View File

@@ -0,0 +1,215 @@
package unicode
import (
"errors"
"reflect"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
exec := func(fn string) string {
stmt, _, err := db.Prepare(`SELECT ` + fn)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnText(0)
}
t.Fatal(stmt.Err())
return ""
}
Register(db)
tests := []struct {
test string
want string
}{
{`upper('hello')`, "HELLO"},
{`lower('HELLO')`, "hello"},
{`upper('привет')`, "ПРИВЕТ"},
{`lower('ПРИВЕТ')`, "привет"},
{`upper('istanbul')`, "ISTANBUL"},
{`upper('istanbul', 'tr-TR')`, "İSTANBUL"},
{`lower('Dünyanın İlk Borsası', 'tr-TR')`, "dünyanın ilk borsası"},
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
{`'Hello' REGEXP 'ell'`, "1"},
{`'Hello' REGEXP 'el.'`, "1"},
{`'Hello' LIKE 'hel_'`, "0"},
{`'Hello' LIKE 'hel%'`, "1"},
{`'Hello' LIKE 'h_llo'`, "1"},
{`'Hello' LIKE 'hello'`, "1"},
{`'Привет' LIKE 'ПРИВЕТ'`, "1"},
{`'100%' LIKE '100|%' ESCAPE '|'`, "1"},
}
for _, tt := range tests {
t.Run(tt.test, func(t *testing.T) {
if got := exec(tt.test); got != tt.want {
t.Errorf("exec(%q) = %q, want %q", tt.test, got, tt.want)
}
})
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRegister_collation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"}
for stmt.Step() {
got = append(got, stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Error("not equal")
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRegister_error(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`SELECT upper('hello', 'enUS')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT lower('hello', 'enUS')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' LIKE 'HELLO' ESCAPE '\\'`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT icu_load_collation('enUS', 'error')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT icu_load_collation('enUS', '')`)
if err != nil {
t.Error(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func Test_like2regex(t *testing.T) {
const prefix = `(?is)\A`
const sufix = `\z`
tests := []struct {
pattern string
escape rune
want string
}{
{`a`, -1, `a`},
{`a.`, -1, `a\.`},
{`a%`, -1, `a.*`},
{`a\`, -1, `a\\`},
{`a_b`, -1, `a.b`},
{`a|b`, '|', `ab`},
{`a|_`, '|', `a_`},
}
for _, tt := range tests {
t.Run(tt.pattern, func(t *testing.T) {
want := prefix + tt.want + sufix
if got := like2regex(tt.pattern, tt.escape); got != want {
t.Errorf("like2regex() = %q, want %q", got, want)
}
})
}
}

175
func.go Normal file
View File

@@ -0,0 +1,175 @@
package sqlite3
import (
"context"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// AnyCollationNeeded registers a fake collating function
// for any unknown collating sequence.
// The fake collating function works like BINARY.
//
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
}
// CreateCollation defines a new collating sequence.
//
// https://www.sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createCollation,
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
if err := c.error(r); err != nil {
util.DelHandle(c.ctx, funcPtr)
return err
}
return nil
}
// CreateFunction defines a new scalar SQL function.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
call := c.api.createAggregate
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
if _, ok := fn().(WindowFunction); ok {
call = c.api.createWindow
}
r := c.call(call,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// AggregateFunction is the interface an aggregate function should implement.
//
// https://www.sqlite.org/appfunc.html
type AggregateFunction interface {
// Step is invoked to add a row to the current window.
// The function arguments, if any, corresponding to the row being added are passed to Step.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current (or final) value of the aggregate.
Value(ctx Context)
}
// WindowFunction is the interface an aggregate window function should implement.
//
// https://www.sqlite.org/windowfunctions.html
type WindowFunction interface {
AggregateFunction
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
// The function arguments, if any, are those passed to Step for the row being removed.
Inverse(ctx Context, arg ...Value)
}
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{sqlite, pCtx}.ResultError(err)
}
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
return util.GetHandle(sqlite.ctx, pApp)
}
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
// Try loading the handle, if we already have one, or want a new one.
if ptr != 0 || size != 0 {
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
fn := util.GetHandle(sqlite.ctx, handle)
if close != nil {
*close = handle
}
if fn != nil {
return fn
}
}
}
// Create a new aggregate and store the handle.
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
}
return fn
}
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
sqlite: sqlite,
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
}
}
return args
}

154
func_test.go Normal file
View File

@@ -0,0 +1,154 @@
package sqlite3_test
import (
"bytes"
"fmt"
"log"
"regexp"
"golang.org/x/text/collate"
"golang.org/x/text/language"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateCollation() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateCollation("french", collate.New(language.French).Compare)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// cote
// coté
// côte
// côté
// cotée
// coter
}
func ExampleConn_CreateFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// COTE
// COTÉ
// CÔTE
// CÔTÉ
// COTÉE
// COTER
}
func ExampleContext_SetAuxData() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
ctx.SetAuxData(0, r)
re = r
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// cote
// côte
// cotée
}

87
func_win_test.go Normal file
View File

@@ -0,0 +1,87 @@
package sqlite3_test
import (
"fmt"
"log"
"unicode"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateWindowFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 1
// 2
// 2
// 1
// 0
// 0
}
type countASCII struct{ result int }
func newASCIICounter() sqlite3.AggregateFunction {
return &countASCII{}
}
func (f *countASCII) Value(ctx sqlite3.Context) {
ctx.ResultInt(f.result)
}
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if f.isASCII(arg[0]) {
f.result++
}
}
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if f.isASCII(arg[0]) {
f.result--
}
}
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
if arg.Type() != sqlite3.TEXT {
return false
}
for _, c := range arg.RawBlob() {
if c > unicode.MaxASCII {
return false
}
}
return true
}

11
go.mod
View File

@@ -1,13 +1,14 @@
module github.com/ncruces/go-sqlite3
go 1.19
go 1.21
require (
github.com/ncruces/julianday v0.1.5
github.com/ncruces/julianday v1.0.0
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.2.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.0
github.com/tetratelabs/wazero v1.5.0
golang.org/x/sync v0.4.0
golang.org/x/sys v0.13.0
golang.org/x/text v0.13.0
)
retract v0.4.0 // tagged from the wrong branch

18
go.sum
View File

@@ -1,10 +1,12 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=

6
go.work Normal file
View File

@@ -0,0 +1,6 @@
go 1.21
use (
.
./gormlite
)

5
go.work.sum Normal file
View File

@@ -0,0 +1,5 @@
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=

View File

@@ -1,5 +1,7 @@
# GORM SQLite Driver
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
## Usage
```go
@@ -19,6 +21,6 @@ Checkout [https://gorm.io](https://gorm.io) for details.
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
```go
db, err := gorm.Open(gormlite.Open(
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=foreign_keys(1)"),
&gorm.Config{})
```

View File

@@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) {
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
@@ -162,7 +162,7 @@ func parseDDL(strs ...string) (*ddl, error) {
for _, column := range getAllColumns(matches[1]) {
for idx, c := range result.columns {
if c.NameValue.String == column {
c.UniqueValue = sql.NullBool{Bool: true, Valid: true}
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
result.columns[idx] = c
}
}
@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
return &result, nil
}
func (d *ddl) clone() *ddl {
copied := new(ddl)
*copied = *d
copied.fields = make([]string, len(d.fields))
copy(copied.fields, d.fields)
copied.columns = make([]migrator.ColumnType, len(d.columns))
copy(copied.columns, d.columns)
return copied
}
func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}
func (d *ddl) renameTable(dst, src string) error {
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
if replaced == d.head {
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
}
d.head = replaced
return nil
}
func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
@@ -208,6 +235,18 @@ func (d *ddl) removeConstraint(name string) bool {
return false
}
//lint:ignore U1000 ignore unused code.
func (d *ddl) hasConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for _, f := range d.fields {
if reg.MatchString(f) {
return true
}
}
return false
}
func (d *ddl) getColumns() []string {
res := []string{}
@@ -229,3 +268,30 @@ func (d *ddl) getColumns() []string {
}
return res
}
func (d *ddl) alterColumn(name, sql string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return false
}
}
d.fields = append(d.fields, sql)
return true
}
func (d *ddl) removeColumn(name string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}
return false
}

View File

@@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) {
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
@@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
@@ -79,6 +79,24 @@ func TestParseDDL(t *testing.T) {
},
},
},
{
"non-unique index",
[]string{
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}
for _, p := range params {
@@ -104,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
@@ -113,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},

11
gormlite/download.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"

View File

@@ -7,9 +7,15 @@ import (
"gorm.io/gorm"
)
func (dialector Dialector) Translate(err error) error {
if errors.Is(err, sqlite3.CONSTRAINT_UNIQUE) {
func (_Dialector) Translate(err error) error {
switch {
case
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
return gorm.ErrDuplicatedKey
case
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
return gorm.ErrForeignKeyViolated
}
return err
}

View File

@@ -1,7 +0,0 @@
package gormlite
import "errors"
var (
ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
)

View File

@@ -1,19 +1,16 @@
module github.com/ncruces/go-sqlite/gormlite
module github.com/ncruces/go-sqlite3/gormlite
go 1.19
go 1.21
require (
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/ncruces/go-sqlite3 v0.7.1
gorm.io/driver/mysql v1.5.1
gorm.io/gorm v1.25.1
github.com/ncruces/go-sqlite3 v0.9.1
gorm.io/gorm v1.25.5
)
require (
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect
github.com/tetratelabs/wazero v1.5.0 // indirect
golang.org/x/sys v0.13.0 // indirect
)

View File

@@ -1,20 +1,16 @@
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.7.1 h1:SDd3g18RobYi3NM8nZgozHM6jaqIbpMEmX42YGOVcTU=
github.com/ncruces/go-sqlite3 v0.7.1/go.mod h1:n+DEDYam8SK5jmsfUC/9GFhSF0gVHGXiYFXnAo8Jwsc=
github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo=
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw=
gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o=
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View File

@@ -3,7 +3,6 @@ package gormlite
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
@@ -12,11 +11,11 @@ import (
"gorm.io/gorm/schema"
)
type Migrator struct {
type _Migrator struct {
migrator.Migrator
}
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
func (m *_Migrator) RunWithoutForeignKey(fc func() error) error {
var enabled int
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
if enabled == 1 {
@@ -27,7 +26,7 @@ func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
return fc()
}
func (m Migrator) HasTable(value interface{}) bool {
func (m _Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
@@ -35,7 +34,7 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
func (m _Migrator) DropTable(values ...interface{}) error {
return m.RunWithoutForeignKey(func() error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
@@ -52,11 +51,11 @@ func (m Migrator) DropTable(values ...interface{}) error {
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
func (m _Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
func (m _Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
@@ -76,31 +75,24 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
func (m _Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
}
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
})
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
@@ -148,29 +140,23 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func (m _Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, "")
return createSQL, nil, nil
ddl.removeColumn(name)
return ddl, nil, nil
})
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
func (m _Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
var (
constraintName string
constraintSql string
@@ -185,22 +171,16 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
return nil, nil, nil
}
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()
return createSQL, constraintValues, nil
ddl.addConstraint(constraintName, constraintSql)
return ddl, constraintValues, nil
})
})
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
func (m _Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
@@ -210,20 +190,14 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
}
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()
return createSQL, nil, nil
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
ddl.removeConstraint(name)
return ddl, nil, nil
})
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
func (m _Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@@ -244,13 +218,13 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
func (m _Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
@@ -269,7 +243,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
func (m _Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@@ -298,7 +272,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
func (m _Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
@@ -317,18 +291,21 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
if err := m.DropIndex(value, oldName); err != nil {
return err
}
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
func (m _Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@@ -362,7 +339,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
return
}
func (m Migrator) getRawDDL(table string) (string, error) {
func (m _Migrator) getRawDDL(table string) (string, error) {
var createSQL string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
@@ -372,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
return createSQL, nil
}
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
func (m _Migrator) recreateTable(
value interface{}, tablePtr *string,
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
@@ -385,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
return err
}
newTableName := table + "__temp"
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
originDDL, err := parseDDL(rawDDL)
if err != nil {
return err
}
if createSQL == "" {
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
if err != nil {
return err
}
if createDDL == nil {
return nil
}
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
newTableName := table + "__temp"
if err := createDDL.renameTable(newTableName, table); err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createDDL, err := parseDDL(createSQL)
if err != nil {
return err
}
columns := createDDL.getColumns()
createSQL := createDDL.compile()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {

View File

@@ -5,7 +5,6 @@ import (
"context"
"database/sql"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
@@ -14,27 +13,33 @@ import (
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
// Open opens a GORM dialector from a data source name.
func Open(dsn string) gorm.Dialector {
return &_Dialector{DSN: dsn}
}
// Open opens a GORM dialector from a database handle.
func OpenDB(db *sql.DB) gorm.Dialector {
return &_Dialector{Conn: db}
}
type _Dialector struct {
DSN string
Conn gorm.ConnPool
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
func (dialector _Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
conn, err := driver.Open(dialector.DSN, nil)
if err != nil {
return err
}
@@ -49,7 +54,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
@@ -65,7 +70,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
@@ -114,7 +119,7 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
if field.AutoIncrement {
return clause.Expr{SQL: "NULL"}
}
@@ -123,44 +128,77 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return _Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
if strings.Contains(str, ".") {
for idx, str := range strings.Split(str, ".") {
if idx > 0 {
writer.WriteString(".`")
func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) {
var (
underQuoted, selfQuoted bool
continuousBacktick int8
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '`':
continuousBacktick++
if continuousBacktick == 2 {
writer.WriteString("``")
continuousBacktick = 0
}
writer.WriteString(str)
writer.WriteByte('`')
case '.':
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteString("`")
}
writer.WriteByte(v)
continue
default:
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteString("`")
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString("``")
}
writer.WriteByte(v)
}
} else {
writer.WriteString(str)
writer.WriteByte('`')
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``")
}
writer.WriteString("`")
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
func (dialector _Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
func (dialector _Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
if field.AutoIncrement {
// doesn't check `PrimaryKey`, to keep backward compatibility
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
@@ -184,12 +222,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
return string(field.DataType)
}
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}

View File

@@ -6,6 +6,8 @@ import (
"gorm.io/gorm"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -13,22 +15,52 @@ func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
// Custom connection with a custom function called "my_custom_function".
db, err := driver.Open(InMemoryDSN, func(conn *sqlite3.Conn) error {
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText("my-result")
})
})
if err != nil {
t.Fatal(err)
}
rows := []struct {
description string
dialector *Dialector
dialector gorm.Dialector
openSuccess bool
query string
querySuccess bool
}{
{
description: "Default driver",
dialector: &Dialector{
DSN: InMemoryDSN,
},
description: "Default driver",
dialector: Open(InMemoryDSN),
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom function",
dialector: Open(InMemoryDSN),
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: false,
},
{
description: "Custom connection",
dialector: OpenDB(db),
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom connection, custom function",
dialector: OpenDB(db),
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {

View File

@@ -4,11 +4,21 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
rm -rf gorm/ tests/
git clone --filter=blob:none --branch=v1.25.1 https://github.com/go-gorm/gorm.git
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
mv gorm/tests tests
rm -rf gorm/
patch -p1 -N < tests.patch
cd tests
go mod tidy && go test
go mod edit \
-require github.com/ncruces/go-sqlite3/gormlite@v0.0.0 \
-replace github.com/ncruces/go-sqlite3/gormlite=../ \
-replace github.com/ncruces/go-sqlite3=../../ \
-droprequire gorm.io/driver/sqlite \
-dropreplace gorm.io/gorm
go mod tidy && go work use . && go test
cd ..
rm -rf tests/
go work use -r .

View File

@@ -1,51 +1,10 @@
diff --git a/tests/.gitignore b/tests/.gitignore
index 08cb523..72e8ffc 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1 @@
-go.sum
+*
diff --git a/tests/go.mod b/tests/go.mod
index f47d175..dba4a24 100644
--- a/tests/go.mod
+++ b/tests/go.mod
@@ -7,13 +7,10 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.8
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
+ github.com/ncruces/go-sqlite3 v0.7.1
golang.org/x/crypto v0.8.0 // indirect
gorm.io/driver/mysql v1.5.0
gorm.io/driver/postgres v1.5.0
- gorm.io/driver/sqlite v1.5.0
gorm.io/driver/sqlserver v1.4.3
- gorm.io/gorm v1.25.0
-)
-
-replace gorm.io/gorm => ../
+ gorm.io/gorm v1.25.1
+)
\ No newline at end of file
diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go
index 1412169..472434b 100644
--- a/tests/scanner_valuer_test.go
+++ b/tests/scanner_valuer_test.go
@@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
return errors.New("Too short")
}
- *data = b[3:]
+ *data = append((*data)[0:], b[3:]...)
return nil
} else if s, ok := value.(string); ok {
- *data = []byte(s)[3:]
+ *data = []byte(s[3:])
return nil
}
diff --git a/tests/tests_test.go b/tests/tests_test.go
index 90eb847..cd9af43 100644
--- a/tests/tests_test.go
+++ b/tests/tests_test.go
@@ -7,9 +7,11 @@ import (
@@ -61,3 +20,12 @@ index 90eb847..cd9af43 100644
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -89,7 +91,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default:
log.Println("testing sqlite3...")
- db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
+ db, err = gorm.Open(sqlite.Open("file:"+filepath.Join(os.TempDir(), "gorm.db")+"?_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"), cfg)
}
if err != nil {

View File

@@ -10,6 +10,32 @@ import (
type i32 interface{ ~int32 | ~uint32 }
type i64 interface{ ~int64 | ~uint64 }
type funcVI[T0 i32] func(context.Context, api.Module, T0)
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]))
}
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVI[T0](fn),
[]api.ValueType{api.ValueTypeI32}, nil).
Export(name)
}
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
}
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {

75
internal/util/handle.go Normal file
View File

@@ -0,0 +1,75 @@
package util
import (
"context"
"io"
"github.com/tetratelabs/wazero/experimental"
)
type handleKey struct{}
type handleState struct {
handles []any
empty int
}
func NewContext(ctx context.Context) context.Context {
state := new(handleState)
ctx = experimental.WithCloseNotifier(ctx, state)
ctx = context.WithValue(ctx, handleKey{}, state)
return ctx
}
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
for _, h := range s.handles {
if c, ok := h.(io.Closer); ok {
c.Close()
}
}
s.handles = nil
s.empty = 0
}
func GetHandle(ctx context.Context, id uint32) any {
if id == 0 {
return nil
}
s := ctx.Value(handleKey{}).(*handleState)
return s.handles[^id]
}
func DelHandle(ctx context.Context, id uint32) error {
if id == 0 {
return nil
}
s := ctx.Value(handleKey{}).(*handleState)
a := s.handles[^id]
s.handles[^id] = nil
s.empty++
if c, ok := a.(io.Closer); ok {
return c.Close()
}
return nil
}
func AddHandle(ctx context.Context, a any) (id uint32) {
if a == nil {
panic(NilErr)
}
s := ctx.Value(handleKey{}).(*handleState)
// Find an empty slot.
if s.empty > cap(s.handles)-len(s.handles) {
for id, h := range s.handles {
if h == nil {
s.empty--
s.handles[id] = a
return ^uint32(id)
}
}
}
// Add a new slot.
s.handles = append(s.handles, a)
return -uint32(len(s.handles))
}

View File

@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
if size > math.MaxUint32 {
panic(RangeErr)
}
if size == 0 {
return nil
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)

56
json.go Normal file
View File

@@ -0,0 +1,56 @@
package sqlite3
import (
"encoding/json"
"strconv"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// JSON returns:
// a [json.Marshaler] that can be used as an argument to
// [database/sql.DB.Exec] and similar methods to
// store value as JSON; and
// a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode JSON into value.
func JSON(value any) any {
return jsonValue{value}
}
type jsonValue struct{ any }
func (j jsonValue) MarshalJSON() ([]byte, error) {
return json.Marshal(j.any)
}
func (j jsonValue) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, j.any)
}
func (j jsonValue) Scan(value any) error {
var buf []byte
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = append(buf, "null"...)
default:
panic(util.AssertErr())
}
return j.UnmarshalJSON(buf)
}

View File

@@ -3,7 +3,6 @@ package sqlite3
import (
"context"
"io"
"math"
"os"
"sync"
@@ -23,72 +22,72 @@ import (
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
RuntimeConfig wazero.RuntimeConfig
)
var sqlite3 struct {
var instance struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
ctx := context.Background()
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
func compileSQLite() {
if RuntimeConfig == nil {
RuntimeConfig = wazero.NewRuntimeConfig()
}
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
env := vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportCallbacks(env)
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
bin, instance.err = os.ReadFile(Path)
if instance.err != nil {
return
}
}
if bin == nil {
sqlite3.err = util.BinaryErr
instance.err = util.BinaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mod api.Module
vfs io.Closer
api sqliteAPI
arg [8]uint64
type sqlite struct {
ctx context.Context
mod api.Module
api sqliteAPI
stack [8]uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = new(module)
m.mod = mod
m.ctx, m.vfs = vfs.NewContext(context.Background())
func instantiateSQLite() (sqlt *sqlite, err error) {
instance.once.Do(compileSQLite)
if instance.err != nil {
return nil, instance.err
}
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig())
if err != nil {
return nil, err
}
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
f := sqlt.mod.ExportedFunction(name)
if f == nil {
err = util.NoFuncErr + util.ErrorString(name)
return nil
@@ -97,15 +96,15 @@ func newModule(mod api.Module) (m *module, err error) {
}
getVal := func(name string) uint32 {
g := mod.ExportedGlobal(name)
g := sqlt.mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return util.ReadUint32(mod, uint32(g.Get()))
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
}
m.api = sqliteAPI{
sqlt.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
@@ -121,6 +120,8 @@ func newModule(mod api.Module) (m *module, err error) {
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
interrupt: getFun("sqlite3_interrupt"),
progressHandler: getFun("sqlite3_progress_handler_go"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
@@ -153,20 +154,44 @@ func newModule(mod api.Module) (m *module, err error) {
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
anyCollation: getFun("sqlite3_anycollseq_init"),
createCollation: getFun("sqlite3_create_collation_go"),
createFunction: getFun("sqlite3_create_function_go"),
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
createWindow: getFun("sqlite3_create_window_function_go"),
aggregateCtx: getFun("sqlite3_aggregate_context"),
userData: getFun("sqlite3_user_data"),
setAuxData: getFun("sqlite3_set_auxdata_go"),
getAuxData: getFun("sqlite3_get_auxdata"),
valueType: getFun("sqlite3_value_type"),
valueInteger: getFun("sqlite3_value_int64"),
valueFloat: getFun("sqlite3_value_double"),
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
resultErrorBig: getFun("sqlite3_result_error_toobig"),
}
if err != nil {
return nil, err
}
return m, nil
return sqlt, nil
}
func (m *module) close() error {
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
func (sqlt *sqlite) close() error {
return sqlt.mod.Close(sqlt.ctx)
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
@@ -177,17 +202,19 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
if r := m.call(m.api.errstr, rc); r != 0 {
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
}
if handle != 0 {
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
}
@@ -198,60 +225,58 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (m *module) call(fn api.Function, params ...uint64) uint64 {
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
copy(sqlt.stack[:], params)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return m.arg[0]
return sqlt.stack[0]
}
func (m *module) free(ptr uint32) {
func (sqlt *sqlite) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
sqlt.call(sqlt.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
func (sqlt *sqlite) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
ptr := uint32(m.call(m.api.malloc, size))
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
if b == nil {
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if (*[0]byte)(b) == nil {
return 0
}
ptr := m.new(uint64(len(b)))
util.WriteBytes(m.mod, ptr, b)
ptr := sqlt.new(uint64(len(b)))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
util.WriteString(m.mod, ptr, s)
func (sqlt *sqlite) newString(s string) uint32 {
ptr := sqlt.new(uint64(len(s) + 1))
util.WriteString(sqlt.mod, ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
func (sqlt *sqlite) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
sqlt: sqlt,
size: uint32(size),
base: sqlt.new(size),
}
}
type arena struct {
m *module
sqlt *sqlite
ptrs []uint32
base uint32
next uint32
@@ -259,17 +284,17 @@ type arena struct {
}
func (a *arena) free() {
if a.m == nil {
if a.sqlt == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
a.sqlt.free(a.base)
a.sqlt = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
a.sqlt.free(ptr)
}
a.ptrs = nil
a.next = 0
@@ -281,7 +306,7 @@ func (a *arena) new(size uint64) uint32 {
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
ptr := a.sqlt.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
@@ -291,13 +316,13 @@ func (a *arena) bytes(b []byte) uint32 {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
util.WriteBytes(a.sqlt.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
util.WriteString(a.m.mod, ptr, s)
util.WriteString(a.sqlt.mod, ptr, s)
return ptr
}
@@ -316,11 +341,13 @@ type sqliteAPI struct {
reset api.Function
step api.Function
exec api.Function
interrupt api.Function
progressHandler api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindNull api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
@@ -348,5 +375,43 @@ type sqliteAPI struct {
changes api.Function
lastRowid api.Function
autocommit api.Function
anyCollation api.Function
createCollation api.Function
createFunction api.Function
createAggregate api.Function
createWindow api.Function
aggregateCtx api.Function
userData api.Function
setAuxData api.Function
getAuxData api.Function
valueType api.Function
valueInteger api.Function
valueFloat api.Function
valueText api.Function
valueBlob api.Function
valueBytes api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
destructor uint32
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", callbackProgress)
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}

1
sqlite3/.gitignore vendored
View File

@@ -1,3 +1,4 @@
ext/
sqlite3.c
sqlite3.h
sqlite3ext.h

View File

@@ -1,20 +0,0 @@
--- sqlite3.c.orig
+++ sqlite3.c
@@ -60425,7 +60425,7 @@
int rc = SQLITE_OK; /* Return code */
int tempFile = 0; /* True for temp files (incl. in-memory files) */
int memDb = 0; /* True if this is an in-memory file */
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
int memJM = 0; /* Memory journal mode */
#else
# define memJM 0
@@ -60628,7 +60628,7 @@
int fout = 0; /* VFS flags returned by xOpen() */
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
assert( !memDb );
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
#endif
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;

View File

@@ -3,32 +3,33 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3440000.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
patch < vfs_find.patch
patch < deserialize.patch
cat *.patch | patch --posix
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/anycollseq.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/test/speedtest1.c"
cd ~-

View File

@@ -1 +0,0 @@
*.c

41
sqlite3/func.c Normal file
View File

@@ -0,0 +1,41 @@
#include <stddef.h>
#include "sqlite3.h"
int go_compare(void *, int, const void *, int, const void *);
void go_func(sqlite3_context *, int, sqlite3_value **);
void go_step(sqlite3_context *, int, sqlite3_value **);
void go_final(sqlite3_context *);
void go_value(sqlite3_context *);
void go_inverse(sqlite3_context *, int, sqlite3_value **);
void go_destroy(void *);
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
go_destroy);
}
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
go_func, /*step=*/NULL, /*final=*/NULL,
go_destroy);
}
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
int nArg, int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, /*value=*/NULL,
/*inverse=*/NULL, go_destroy);
}
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, go_value,
go_inverse, go_destroy);
}
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}

34
sqlite3/isoweek.patch Normal file
View File

@@ -0,0 +1,34 @@
# ISO week date specifiers.
# https://sqlite.org/forum/forumpost/73d99e4497e8e6a7
--- sqlite3.c.orig
+++ sqlite3.c
@@ -1373,6 +1373,29 @@ static void strftimeFunc(
sqlite3_str_appendchar(&sRes, 1, c);
break;
}
+ case 'V': /* Fall thru */
+ case 'G': {
+ DateTime y = x;
+ computeJD(&y);
+ y.validYMD = 0;
+ /* Adjust date to Thursday this week:
+ The number in parentheses is 0 for Monday, 3 for Thursday */
+ y.iJD += (3 - (((y.iJD+43200000)/86400000) % 7))*86400000;
+ computeYMD(&y);
+ if( cf=='G' ){
+ sqlite3_str_appendf(&sRes,"%04d",y.Y);
+ }else{
+ int nDay; /* Number of days since 1st day of year */
+ i64 tJD = y.iJD;
+ y.validJD = 0;
+ y.M = 1;
+ y.D = 1;
+ computeJD(&y);
+ nDay = (int)((tJD-y.iJD+43200000)/86400000);
+ sqlite3_str_appendf(&sRes,"%02d",nDay/7+1);
+ }
+ break;
+ }
case 'Y': {
sqlite3_str_appendf(&sRes,"%04d",x.Y);
break;

View File

@@ -0,0 +1,14 @@
# Use exclusive locking mode for WAL databases with v1 VFSes.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -63210,7 +63210,9 @@
SQLITE_PRIVATE int sqlite3PagerWalSupported(Pager *pPager){
const sqlite3_io_methods *pMethods = pPager->fd->pMethods;
if( pPager->noLock ) return 0;
- return pPager->exclusiveMode || (pMethods->iVersion>=2 && pMethods->xShmMap);
+ if( pMethods->iVersion>=2 && pMethods->xShmMap ) return 1;
+ pPager->exclusiveMode = 1;
+ return 1;
}
/*

View File

@@ -1,19 +1,17 @@
#include <stdbool.h>
#include <stddef.h>
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS
#include "vfs.c"
// Extensions
#include "ext/anycollseq.c"
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
#include "func.c"
#include "progress.c"
#include "time.c"
__attribute__((constructor)) void init() {

9
sqlite3/progress.c Normal file
View File

@@ -0,0 +1,9 @@
#include <stddef.h>
#include "sqlite3.h"
int go_progress(void *);
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL);
}

View File

@@ -5,12 +5,28 @@
#define SQLITE_OS_OTHER 1
#define SQLITE_BYTEORDER 1234
#define HAVE_INT8_T 1
#define HAVE_INT16_T 1
#define HAVE_INT32_T 1
#define HAVE_INT64_T 1
#define HAVE_UINT8_T 1
#define HAVE_UINT16_T 1
#define HAVE_UINT32_T 1
#define HAVE_UINT64_T 1
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define HAVE_LOG2 1
#define HAVE_LOG10 1
#define HAVE_ISNAN 1
#define HAVE_USLEEP 1
#define HAVE_NANOSLEEP 1
#define HAVE_GMTIME_R 1
#define HAVE_LOCALTIME_S 1
#define HAVE_MALLOC_H 1
#define HAVE_MALLOC_USABLE_SIZE 1
// Recommended Options
@@ -23,12 +39,12 @@
#define SQLITE_MAX_EXPR_DEPTH 0
#define SQLITE_OMIT_DECLTYPE
#define SQLITE_OMIT_DEPRECATED
#define SQLITE_OMIT_PROGRESS_CALLBACK
#define SQLITE_OMIT_SHARED_CACHE
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
#define SQLITE_ENABLE_ATOMIC_WRITE
@@ -36,12 +52,9 @@
// Because WASM does not support shared memory,
// SQLite disables WAL for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// We patch SQLite to use exclusive locking mode instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#ifndef SQLITE_DEFAULT_LOCKING_MODE
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Amalgamated Extensions
@@ -54,6 +67,8 @@
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
#define SQLITE_SOUNDEX
// Session Extension
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK

View File

@@ -1,3 +1,4 @@
#include <stddef.h>
#include <string.h>
#include "sqlite3.h"
@@ -26,7 +27,63 @@ static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
return rc;
}
static void json_time_func(sqlite3_context *context, int argc,
sqlite3_value **argv) {
DateTime x;
if (isDate(context, argc, argv, &x)) return;
if (x.tzSet && x.tz) {
x.iJD += x.tz * 60000;
if (!validJulianDay(x.iJD)) return;
x.validYMD = 0;
x.validHMS = 0;
}
computeYMD_HMS(&x);
sqlite3 *db = sqlite3_context_db_handle(context);
sqlite3_str *res = sqlite3_str_new(db);
sqlite3_str_appendf(res, "%04d-%02d-%02dT%02d:%02d:%02d", //
x.Y, x.M, x.D, //
x.h, x.m, (int)(x.iJD / 1000 % 60));
if (x.useSubsec) {
int rem = x.iJD % 1000;
if (rem) {
sqlite3_str_appendchar(res, 1, '.');
sqlite3_str_appendchar(res, 1, '0' + rem / 100);
if ((rem %= 100)) {
sqlite3_str_appendchar(res, 1, '0' + rem / 10);
if ((rem %= 10)) {
sqlite3_str_appendchar(res, 1, '0' + rem);
}
}
}
}
if (x.tz) {
sqlite3_str_appendf(res, "%+03d:%02d", x.tz / 60, abs(x.tz) % 60);
} else {
sqlite3_str_appendchar(res, 1, 'Z');
}
int rc = sqlite3_str_errcode(res);
if (rc) {
sqlite3_result_error_code(context, rc);
return;
}
int n = sqlite3_str_length(res);
sqlite3_result_text(context, sqlite3_str_finish(res), n, sqlite3_free);
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
sqlite3_create_collation_v2(db, "time", SQLITE_UTF8, /*arg=*/NULL,
time_collation,
/*destroy=*/NULL);
sqlite3_create_function_v2(
db, "json_time", -1,
SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, /*arg=*/NULL,
json_time_func, /*step=*/NULL, /*final=*/NULL, /*destroy=*/NULL);
return SQLITE_OK;
}

45
sqlite3/timezone.patch Normal file
View File

@@ -0,0 +1,45 @@
# Set UTC timezone, compute local offset.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -340,6 +340,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
p->iJD = sqlite3StmtCurrentTime(context);
if( p->iJD>0 ){
p->validJD = 1;
+ p->tzSet = 1;
return 0;
}else{
return 1;
@@ -355,6 +356,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
static void setRawDateNumber(DateTime *p, double r){
p->s = r;
p->rawS = 1;
+ p->tzSet = 1;
if( r>=0.0 && r<5373484.5 ){
p->iJD = (sqlite3_int64)(r*86400000.0 + 0.5);
p->validJD = 1;
@@ -731,7 +733,16 @@ static int parseModifier(
** show local time.
*/
if( sqlite3_stricmp(z, "localtime")==0 && sqlite3NotPureFunc(pCtx) ){
- rc = toLocaltime(p, pCtx);
+ if( p->tzSet!=0 || p->tz==0 ) {
+ rc = toLocaltime(p, pCtx);
+ i64 iOrigJD = p->iJD;
+ p->tzSet = 0;
+ computeJD(p);
+ p->tz = (p->iJD-iOrigJD)/60000;
+ if( abs(p->tz)>= 900 ) p->tz = 0;
+ } else {
+ rc = 0;
+ }
}
break;
}
@@ -781,6 +792,7 @@ static int parseModifier(
p->validJD = 1;
p->tzSet = 1;
}
+ p->tz = 0;
rc = SQLITE_OK;
}
#endif

View File

@@ -1,3 +1,5 @@
#include <stdbool.h>
#include <stddef.h>
#include <time.h>
#include "sqlite3.h"
@@ -90,22 +92,25 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
if (zVfsName) {
static sqlite3_vfs *go_vfs_list;
sqlite3_vfs *found = NULL;
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
sqlite3_vfs *it = *next;
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
if (!strcmp(zVfsName, it->zName) && go_vfs_find(it->zName)) {
return it;
}
}
for (sqlite3_vfs **ptr = &go_vfs_list; *ptr;) {
sqlite3_vfs *it = *ptr;
if (go_vfs_find(it->zName)) {
if (!strcmp(zVfsName, it->zName)) found = it;
next = &it->pNext;
ptr = &it->pNext;
} else {
*next = it->pNext;
*ptr = it->pNext;
free(it);
}
}
if (found) {
return found;
}
if (go_vfs_find(zVfsName)) {
sqlite3_vfs *prev = go_vfs_list;
sqlite3_vfs *head = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
@@ -114,7 +119,7 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = name,
.pNext = prev,
.pNext = head,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
@@ -132,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
return sqlite3_vfs_find_orig(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 280, "Unexpected offset");
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");

View File

@@ -1,3 +1,4 @@
# Wrap sqlite3_vfs_find.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -25394,7 +25394,7 @@

View File

@@ -3,6 +3,7 @@ package sqlite3
import (
"bytes"
"math"
"os"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -12,67 +13,73 @@ func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
func Test_sqlite_error_OOM(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
defer func() { _ = recover() }()
m.error(uint64(NOMEM), 0)
sqlite.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_closed(t *testing.T) {
func Test_sqlite_call_closed(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
m.close()
sqlite.close()
defer func() { _ = recover() }()
m.call(m.api.free)
sqlite.call(sqlite.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
func Test_sqlite_new(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(math.MaxUint32)
sqlite.new(math.MaxUint32)
t.Error("want panic")
})
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
func Test_sqlite_newArena(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
arena := m.newArena(16)
arena := sqlite.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -80,7 +87,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -89,7 +96,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
@@ -101,121 +108,130 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
t.Errorf("got %q, want %q", got, title)
}
arena.free()
}
func TestConn_newBytes(t *testing.T) {
func Test_sqlite_newBytes(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newBytes(nil)
ptr := sqlite.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = m.newBytes(buf)
ptr = sqlite.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
ptr = sqlite.newBytes(buf[:0])
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
if got := util.View(sqlite.mod, ptr, 0); got != nil {
t.Errorf("got %q, want nil", got)
}
}
func TestConn_newString(t *testing.T) {
func Test_sqlite_newString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newString("")
ptr := sqlite.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = m.newString(str)
ptr = sqlite.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_getString(t *testing.T) {
func Test_sqlite_getString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newString("")
ptr := sqlite.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = m.newString(str)
ptr = sqlite.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := util.ReadString(m.mod, ptr, 0); got != "" {
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
util.ReadString(m.mod, ptr, uint32(len(want)/2))
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
util.ReadString(m.mod, 0, math.MaxUint32)
util.ReadString(sqlite.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}
func TestConn_free(t *testing.T) {
func Test_sqlite_free(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
m.free(0)
sqlite.free(0)
ptr := m.new(1)
ptr := sqlite.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
m.free(ptr)
sqlite.free(ptr)
}

80
stmt.go
View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -61,12 +63,12 @@ func (s *Stmt) ClearBindings() error {
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r == _ROW {
switch r {
case _ROW:
return true
}
if r == _DONE {
case _DONE:
s.err = nil
} else {
default:
s.err = s.c.error(r)
}
return false
@@ -131,10 +133,11 @@ func (s *Stmt) BindName(param int) string {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBool(param int, value bool) error {
var i int64
if value {
return s.BindInt64(param, 1)
i = 1
}
return s.BindInt64(param, 0)
return s.BindInt64(param, i)
}
// BindInt binds an int to the prepared statement.
@@ -234,7 +237,7 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
@@ -247,6 +250,23 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
ptr := s.c.newBytes(data)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(data)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -374,18 +394,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
return s.columnRawBytes(col, uint32(r))
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -397,20 +406,45 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
}
ptr := uint32(r)
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
r := s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
}
// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
var data []byte
switch s.ColumnType(col) {
case NULL:
data = append(data, "null"...)
case TEXT:
data = s.ColumnRawText(col)
case BLOB:
data = s.ColumnRawBlob(col)
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.

View File

@@ -43,7 +43,7 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "foo.db")+
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
"?_pragma=busy_timeout(10000)&_pragma=synchronous(off)")
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}

View File

@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
go cancel()
// Interrupting works.
err = stmt.Exec()

View File

@@ -1,14 +1,20 @@
package tests
import (
"os"
"path/filepath"
"testing"
_ "embed"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/wal.db
var waldb []byte
func TestDB_memory(t *testing.T) {
t.Parallel()
testDB(t, ":memory:")
@@ -19,6 +25,23 @@ func TestDB_file(t *testing.T) {
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func TestDB_nolock(t *testing.T) {
t.Parallel()
testDB(t, "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?nolock=1")
}
func TestDB_wal(t *testing.T) {
t.Parallel()
wal := filepath.Join(t.TempDir(), "test.db")
err := os.WriteFile(wal, waldb, 0666)
if err != nil {
t.Fatal(err)
}
testDB(t, wal)
}
func TestDB_vfs(t *testing.T) {
testDB(t, "file:test.db?vfs=memdb")
}

View File

@@ -2,10 +2,9 @@ package tests
import (
"context"
"database/sql"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}

218
tests/func_test.go Normal file
View File

@@ -0,0 +1,218 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestCreateFunction(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg := arg[0]; arg.Int() {
case 0:
ctx.ResultInt(arg.Int())
case 1:
ctx.ResultInt64(arg.Int64())
case 2:
ctx.ResultBool(arg.Bool())
case 3:
ctx.ResultFloat(arg.Float())
case 4:
ctx.ResultText(arg.Text())
case 5:
ctx.ResultBlob(arg.Blob(nil))
case 6:
ctx.ResultZeroBlob(arg.Int64())
case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8:
var v any
if err := arg.JSON(&v); err != nil {
ctx.ResultError(err)
} else {
ctx.ResultJSON(v)
}
case 9:
ctx.ResultValue(arg)
case 10:
ctx.ResultNull()
case 11:
ctx.ResultError(sqlite3.FULL)
}
})
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 1", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 1 {
t.Errorf("got %v, want 2", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnInt64(0); got != 3 {
t.Errorf("got %v, want 3", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnText(0); got != "4" {
t.Errorf("got %s, want 4", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
t.Errorf("got %s, want 5", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
t.Errorf("got %v, want 6", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
t.Errorf("got %v, want 7", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 8 {
t.Errorf("got %v, want 8", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 9 {
t.Errorf("got %v, want 9", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
t.Error("want error")
}
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
t.Errorf("got %v, want sqlite3.FULL", err)
}
}
func TestAnyCollationNeeded(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
db.AnyCollationNeeded()
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 2, 1}
names := []string{"go", "whatever", "zig"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
}

68
tests/json_test.go Normal file
View File

@@ -0,0 +1,68 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/julianday"
)
func TestJSON(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
nil, 1, math.Pi, reference,
)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
sqlite3.JSON(math.Pi), sqlite3.JSON(false),
julianday.Format(reference), sqlite3.JSON([]string{}))
if err != nil {
t.Fatal(err)
}
rows, err := db.Query("SELECT * FROM test")
if err != nil {
t.Fatal(err)
}
want := []string{
"null", "1", "3.141592653589793",
`"2013-10-07T04:23:19.12-04:00"`,
"3.141592653589793", "false",
"2456572.849526851851852", "[]",
}
for rows.Next() {
var got json.RawMessage
err = rows.Scan(sqlite3.JSON(&got))
if err != nil {
t.Fatal(err)
}
if string(got) != want[0] {
t.Errorf("got %q, want %q", got, want[0])
}
want = want[1:]
}
}

View File

@@ -25,7 +25,6 @@ func TestParallel(t *testing.T) {
name := "file:" +
filepath.Join(t.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
@@ -42,7 +41,6 @@ func TestMemory(t *testing.T) {
name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
@@ -59,7 +57,6 @@ func TestMultiProcess(t *testing.T) {
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
@@ -93,7 +90,6 @@ func TestChildProcess(t *testing.T) {
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
@@ -128,10 +124,7 @@ func testParallel(t *testing.T, name string, n int) {
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA locking_mode=normal;
`)
err = db.Exec(`PRAGMA busy_timeout=10000`)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindJSON(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
// The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 0 {
t.Errorf("got %v, want zero", got)
}
}
if stmt.Step() {
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
var got float32
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 1 {
t.Errorf("got %v, want one", got)
}
}
if stmt.Step() {
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
var got json.Number
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != "2" {
t.Errorf("got %v, want two", got)
}
}
if stmt.Step() {
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
var got float64
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != math.Pi {
t.Errorf("got %v, want π", got)
}
}
if stmt.Step() {
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "true" {
t.Errorf("got %q, want true", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
t.Errorf("got %q, want true", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if err := stmt.Close(); err != nil {

BIN
tests/testdata/wal.db vendored Normal file

Binary file not shown.

View File

@@ -1,11 +1,13 @@
package tests
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
)
func TestTimeFormat_Encode(t *testing.T) {
@@ -39,70 +41,72 @@ func TestTimeFormat_Encode(t *testing.T) {
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
zone := time.FixedZone("", -4*3600)
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, zone)
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, zone)
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantLoc *time.Location
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, time.UTC, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, time.UTC, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, time.UTC, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, time.UTC, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, zone, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, time.UTC, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, zone, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, time.UTC, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, time.UTC, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, nil, true},
}
for _, tt := range tests {
@@ -112,13 +116,48 @@ func TestTimeFormat_Decode(t *testing.T) {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
if got.Sub(tt.want).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
if got.Location().String() != tt.wantLoc.String() {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got.Location(), tt.wantLoc)
}
})
}
}
func TestTimeFormat_Scanner(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(
`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.TimeFormat7TZ.Encode(reference))
if err != nil {
t.Fatal(err)
}
var got time.Time
err = db.QueryRow("SELECT * FROM test").Scan(sqlite3.TimeFormatAuto.Scanner(&got))
if err != nil {
t.Fatal(err)
}
if !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
@@ -167,3 +206,57 @@ func TestDB_timeCollation(t *testing.T) {
t.Fatal(err)
}
}
func TestDB_isoWeek(t *testing.T) {
t.Parallel()
tests := []time.Time{
time.Date(1977, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1977, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1977, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1978, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1978, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1978, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1979, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1979, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1979, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1980, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 28, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 29, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 30, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1981, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1981, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 3, 0, 0, 0, 0, time.UTC),
}
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT strftime('%G-W%V-%u', ?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
for _, tm := range tests {
stmt.BindTime(1, tm, sqlite3.TimeFormatDefault)
if stmt.Step() {
y, w := tm.ISOWeek()
d := tm.Weekday()
if d == 0 {
d = 7
}
want := fmt.Sprintf("%04d-W%02d-%d", y, w, d)
if got := stmt.ColumnText(0); got != want {
t.Errorf("got %q, want %q (%v)", got, want, tm)
}
}
stmt.Reset()
}
}

56
time.go
View File

@@ -164,9 +164,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)), nil
return time.Unix(int64(sec), int64(nsec)).UTC(), nil
case int64:
return time.Unix(v, 0), nil
return time.Unix(v, 0).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -181,9 +181,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(math.Floor(v))), nil
return time.UnixMilli(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMilli(int64(v)), nil
return time.UnixMilli(int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -198,9 +198,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(math.Floor(v))), nil
return time.UnixMicro(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMicro(int64(v)), nil
return time.UnixMicro(int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -215,9 +215,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(math.Floor(v))), nil
return time.Unix(0, int64(math.Floor(v))).UTC(), nil
case int64:
return time.Unix(0, int64(v)), nil
return time.Unix(0, int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -238,26 +238,16 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
dates := []TimeFormat{
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
TimeFormat1,
TimeFormat9, TimeFormat8,
TimeFormat6, TimeFormat5,
TimeFormat3, TimeFormat2, TimeFormat1,
}
for _, f := range dates {
t, err := time.Parse(string(f), s)
t, err := f.Decode(s)
if err == nil {
return t, nil
}
}
times := []TimeFormat{
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
}
for _, f := range times {
t, err := time.Parse(string(f), s)
if err == nil {
return t.AddDate(2000, 0, 0), nil
}
}
}
switch v := v.(type) {
case float64:
@@ -314,7 +304,10 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
if err != nil {
return time.Time{}, err
}
return t.AddDate(2000, 0, 0), nil
default:
s, ok := v.(string)
@@ -338,3 +331,20 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
}
return t, nil
}
// Scanner returns a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode a time value into dest using this format.
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
return timeScanner{dest, f}
}
type timeScanner struct {
*time.Time
TimeFormat
}
func (s timeScanner) Scan(src any) (err error) {
*s.Time, err = s.Decode(src)
return
}

148
value.go Normal file
View File

@@ -0,0 +1,148 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Value is any value that can be stored in a database table.
//
// https://www.sqlite.org/c3ref/value.html
type Value struct {
*sqlite
handle uint32
}
// Type returns the initial [Datatype] of the value.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Type() Datatype {
r := v.call(v.api.valueType, uint64(v.handle))
return Datatype(r)
}
// Bool returns the value as a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as integers,
// with 0 converted to false and any other value to true.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Bool() bool {
if i := v.Int64(); i != 0 {
return true
}
return false
}
// Int returns the value as an int.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Int() int {
return int(v.Int64())
}
// Int64 returns the value as an int64.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Int64() int64 {
r := v.call(v.api.valueInteger, uint64(v.handle))
return int64(r)
}
// Float returns the value as a float64.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Float() float64 {
r := v.call(v.api.valueFloat, uint64(v.handle))
return math.Float64frombits(r)
}
// Time returns the value as a [time.Time].
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Time(format TimeFormat) time.Time {
var a any
switch v.Type() {
case INTEGER:
a = v.Int64()
case FLOAT:
a = v.Float()
case TEXT, BLOB:
a = v.Text()
case NULL:
return time.Time{}
default:
panic(util.AssertErr())
}
t, _ := format.Decode(a)
return t
}
// Text returns the value as a string.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Text() string {
return string(v.RawText())
}
// Blob appends to buf and returns
// the value as a []byte.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Blob(buf []byte) []byte {
return append(buf, v.RawBlob()...)
}
// RawText returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
r := v.call(v.api.valueText, uint64(v.handle))
return v.rawBytes(uint32(r))
}
// RawBlob returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
r := v.call(v.api.valueBlob, uint64(v.handle))
return v.rawBytes(uint32(r))
}
func (v Value) rawBytes(ptr uint32) []byte {
if ptr == 0 {
return nil
}
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
var data []byte
switch v.Type() {
case NULL:
data = append(data, "null"...)
case TEXT:
data = v.RawText()
case BLOB:
data = v.RawBlob()
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}

View File

@@ -2,8 +2,6 @@
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
It replaces the default VFS with a pure Go implementation,
that is tested on Linux, macOS and Windows,
but which should also work on illumos and the various BSDs.
It replaces the default SQLite VFS with a pure Go implementation.
It also exposes interfaces that should allow you to implement your own custom VFSes.

View File

@@ -15,7 +15,7 @@ type VFS interface {
FullPathname(name string) (string, error)
}
// VFSParams extends VFS to with the ability to handle URI parameters
// VFSParams extends VFS with the ability to handle URI parameters
// through the OpenParams method.
//
// https://www.sqlite.org/c3ref/uri_boolean.html
@@ -47,7 +47,7 @@ type File interface {
// FileLockState extends File to implement the
// SQLITE_FCNTL_LOCKSTATE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntllockstate
type FileLockState interface {
File
LockState() LockLevel
@@ -56,7 +56,7 @@ type FileLockState interface {
// FileSizeHint extends File to implement the
// SQLITE_FCNTL_SIZE_HINT file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlsizehint
type FileSizeHint interface {
File
SizeHint(size int64) error
@@ -65,16 +65,25 @@ type FileSizeHint interface {
// FileHasMoved extends File to implement the
// SQLITE_FCNTL_HAS_MOVED file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlhasmoved
type FileHasMoved interface {
File
HasMoved() (bool, error)
}
// FileOverwrite extends File to implement the
// SQLITE_FCNTL_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntloverwrite
type FileOverwrite interface {
File
Overwrite() error
}
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlpowersafeoverwrite
type FilePowersafeOverwrite interface {
File
PowersafeOverwrite() bool
@@ -84,7 +93,7 @@ type FilePowersafeOverwrite interface {
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlcommitphasetwo
type FileCommitPhaseTwo interface {
File
CommitPhaseTwo() error
@@ -94,7 +103,7 @@ type FileCommitPhaseTwo interface {
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlbeginatomicwrite
type FileBatchAtomicWrite interface {
File
BeginAtomicWrite() error

9
vfs/clear.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !go1.21
package vfs
func clear(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -9,7 +9,6 @@ import (
"path/filepath"
"runtime"
"syscall"
"time"
)
type vfsOS struct{}
@@ -124,11 +123,10 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock LockLevel
psow bool
syncDir bool
readOnly bool
lock LockLevel
psow bool
syncDir bool
readOnly bool
}
var (

View File

@@ -1,11 +1,6 @@
package vfs
import (
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_PENDING_BYTE = 0x40000000
@@ -48,7 +43,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
if f.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetSharedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_SHARED
@@ -59,7 +54,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
if f.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetReservedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_RESERVED
@@ -77,7 +72,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
}
f.lock = LOCK_PENDING
}
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetExclusiveLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_EXCLUSIVE
@@ -133,18 +128,3 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
}
return osCheckReservedLock(f.File)
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -12,13 +11,6 @@ import (
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
@@ -41,8 +33,7 @@ func Test_vfsLock(t *testing.T) {
pOutput = 32
)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
@@ -212,9 +203,4 @@ func Test_vfsLock(t *testing.T) {
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}

10
vfs/memdb/clear.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build !go1.21
package memdb
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = zero
}
}

View File

@@ -133,7 +133,7 @@ func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
n = copy((*m.data[base])[rest:], b)
if n < len(b) {
// Assume writes are page aligned.
return 0, io.ErrShortWrite
return n, io.ErrShortWrite
}
if size := off + int64(len(b)); size > m.size {
m.size = size
@@ -176,6 +176,8 @@ func (m *memFile) Size() (int64, error) {
return m.size, nil
}
const spinWait = 25 * time.Microsecond
func (m *memFile) Lock(lock vfs.LockLevel) error {
if m.lock >= lock {
return nil
@@ -187,17 +189,11 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
deadline := time.Now().Add(time.Millisecond)
switch lock {
case vfs.LOCK_SHARED:
for m.pending != nil {
if time.Now().After(deadline) {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
runtime.Gosched()
m.lockMtx.Lock()
if m.pending != nil {
return sqlite3.BUSY
}
m.shared++
@@ -216,8 +212,8 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
m.pending = m
}
for m.shared > 1 {
if time.Now().After(deadline) {
for before := time.Now(); m.shared > 1; {
if time.Since(before) > spinWait {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
@@ -291,10 +287,3 @@ func divRoundUp(a, b int64) int64 {
func modRoundUp(a, b int64) int64 {
return b - (b-a%b)%b
}
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = zero
}
}

View File

@@ -1,4 +1,4 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
//go:build (freebsd || openbsd || netbsd || dragonfly || sqlite3_flock) && !sqlite3_nosys
package vfs
@@ -20,16 +20,16 @@ func osUnlock(file *os.File, start, len int64) _ErrorCode {
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
before := time.Now()
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)

View File

@@ -1,4 +1,4 @@
//go:build !sqlite3_bsd
//go:build !sqlite3_flock && !sqlite3_nosys
package vfs

View File

@@ -1,3 +1,5 @@
//go:build !sqlite3_nosys
package vfs
import (

33
vfs/os_nolock.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build !(linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos) || sqlite3_nosys
package vfs
import "os"
func osGetSharedLock(file *os.File) _ErrorCode {
return _IOERR_RDLOCK
}
func osGetReservedLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osGetPendingLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
return _IOERR_RDLOCK
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
return _IOERR_UNLOCK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
return false, _IOERR_CHECKRESERVEDLOCK
}

View File

@@ -1,4 +1,4 @@
//go:build linux || illumos
//go:build (linux || illumos) && !sqlite3_nosys
package vfs
@@ -27,16 +27,16 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
Start: start,
Len: len,
}
before := time.Now()
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)

36
vfs/os_std_access.go Normal file
View File

@@ -0,0 +1,36 @@
//go:build !unix || sqlite3_nosys
package vfs
import (
"io/fs"
"os"
)
const (
_S_IREAD = 0400
_S_IWRITE = 0200
_S_IEXEC = 0100
)
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = _S_IREAD
if flags == ACCESS_READWRITE {
want |= _S_IWRITE
}
if fi.IsDir() {
want |= _S_IEXEC
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && (!darwin || sqlite3_bsd)
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
package vfs
@@ -7,10 +7,6 @@ import (
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {

14
vfs/os_std_mode.go Normal file
View File

@@ -0,0 +1,14 @@
//go:build !unix || sqlite3_nosys
package vfs
import "os"
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}

12
vfs/os_std_open.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build !windows || sqlite3_nosys
package vfs
import (
"io/fs"
"os"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}

9
vfs/os_std_sync.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
package vfs
import "os"
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}

View File

@@ -1,20 +1,14 @@
//go:build unix
//go:build unix && !sqlite3_nosys
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
@@ -37,64 +31,3 @@ func osSetMode(file *os.File, modeof string) error {
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

82
vfs/os_unix2.go Normal file
View File

@@ -0,0 +1,82 @@
//go:build (linux || darwin || freebsd || openbsd || netbsd || dragonfly || illumos) && !sqlite3_nosys
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osGetSharedLock(file *os.File) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

View File

@@ -1,3 +1,5 @@
//go:build !sqlite3_nosys
package vfs
import (
@@ -25,40 +27,9 @@ func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
func osGetSharedLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
rc := osReadLock(file, _PENDING_BYTE, 1, 0)
if rc == _OK {
// Acquire the SHARED lock.
@@ -70,16 +41,22 @@ func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
if rc != _OK {
// Reacquire the SHARED lock.
@@ -125,6 +102,11 @@ func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
return _OK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
@@ -138,6 +120,7 @@ func osUnlock(file *os.File, start, len uint32) _ErrorCode {
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
before := time.Now()
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
@@ -145,11 +128,16 @@ func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
if err := windows.TimeBeginPeriod(1); err != nil {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
if err := windows.TimeEndPeriod(1); err != nil {
break
}
}
return osLockErrorCode(err, def)
}

View File

@@ -28,7 +28,7 @@ var (
)
// Create creates an immutable database from reader.
// The caller should insure that data from reader does not mutate,
// The caller should ensure that data from reader does not mutate,
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
func Create(name string, reader SizeReaderAt) {
readerMtx.Lock()

View File

@@ -1,10 +1,10 @@
package readervfs_test
import (
"bytes"
"database/sql"
"fmt"
"log"
"strings"
_ "embed"
@@ -15,7 +15,7 @@ import (
)
//go:embed testdata/test.db
var testDB []byte
var testDB string
func Example_http() {
readervfs.Create("demo.db", httpreadat.New("https://www.sanford.io/demo.db"))
@@ -65,7 +65,7 @@ func Example_http() {
}
func Example_embed() {
readervfs.Create("test.db", readervfs.NewSizeReaderAt(bytes.NewReader(testDB)))
readervfs.Create("test.db", readervfs.NewSizeReaderAt(strings.NewReader(testDB)))
defer readervfs.Delete("test.db")
db, err := sql.Open("sqlite3", "file:test.db?vfs=reader")

Some files were not shown because too many files have changed in this diff Show More