Compare commits

..

70 Commits

Author SHA1 Message Date
Nuno Cruces
bc840dcefb SQLite 3.45.0. 2024-01-16 15:53:47 +00:00
Nuno Cruces
c822fa95c7 Batch column scans. (#52) 2024-01-16 15:18:14 +00:00
Nuno Cruces
1b2c267b2b Optimize interrupts. 2024-01-16 15:08:26 +00:00
Nuno Cruces
3d99af86bf Ensure arena alignment. 2024-01-15 10:43:36 +00:00
Nuno Cruces
145bc228af Avoid allocation. 2024-01-12 13:35:21 +00:00
Nuno Cruces
6b0c2c0554 Optimize. (#51) 2024-01-11 02:18:12 +00:00
Nuno Cruces
97f2b73701 Optimize. 2024-01-10 16:53:18 +00:00
Nuno Cruces
cb1e33a32d Benchmarks. 2024-01-10 12:27:19 +00:00
Nuno Cruces
ee48dd5c96 More stats. 2024-01-10 11:39:26 +00:00
Nuno Cruces
af42af2978 More stats. 2024-01-09 03:20:59 +00:00
dependabot[bot]
d48a92fcdf Bump golang.org/x/crypto from 0.17.0 to 0.18.0
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.17.0 to 0.18.0.
- [Commits](https://github.com/golang/crypto/compare/v0.17.0...v0.18.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-08 23:05:11 +00:00
Nuno Cruces
69937fbee5 More vtab API. 2024-01-08 19:23:56 +00:00
dependabot[bot]
2fb325b223 Bump golang.org/x/sync from 0.5.0 to 0.6.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.5.0 to 0.6.0.
- [Commits](https://github.com/golang/sync/compare/v0.5.0...v0.6.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-05 02:16:24 +00:00
dependabot[bot]
f0c583a581 Bump golang.org/x/sys from 0.15.0 to 0.16.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.15.0 to 0.16.0.
- [Commits](https://github.com/golang/sys/compare/v0.15.0...v0.16.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-05 02:00:28 +00:00
Nuno Cruces
17ce949c55 Create osutil. 2024-01-03 12:54:26 +00:00
Nuno Cruces
ae850191c8 Refactor extensions. 2024-01-03 12:43:03 +00:00
Nuno Cruces
fab70ddbec IEEE754 extension. 2023-12-30 10:50:35 +00:00
Nuno Cruces
a3c5f47d79 Update README.md 2023-12-30 00:47:16 +00:00
Nuno Cruces
16b5d80ef7 Internal JSON and pointer wrappers. 2023-12-29 23:42:37 +00:00
Nuno Cruces
7e5a143214 Hash functions. 2023-12-29 23:42:30 +00:00
dependabot[bot]
92d75f7446 Bump cross-platform-actions/action from 0.21.1 to 0.22.0
Bumps [cross-platform-actions/action](https://github.com/cross-platform-actions/action) from 0.21.1 to 0.22.0.
- [Release notes](https://github.com/cross-platform-actions/action/releases)
- [Changelog](https://github.com/cross-platform-actions/action/blob/master/changelog.md)
- [Commits](https://github.com/cross-platform-actions/action/compare/v0.21.1...v0.22.0)

---
updated-dependencies:
- dependency-name: cross-platform-actions/action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-12-28 00:16:27 +00:00
Nuno Cruces
d56ee4ac2c Error logging. 2023-12-27 14:16:00 +00:00
Nuno Cruces
e944d5d8e7 Config. 2023-12-23 14:53:15 +00:00
Nuno Cruces
fde2277b4a wazero v1.6.0. 2023-12-23 13:19:33 +00:00
Nuno Cruces
1ebdeed565 Documentation, issue #45. 2023-12-22 02:45:26 +00:00
Nuno Cruces
89202629ec Increase various limits, fix #45. 2023-12-21 15:08:19 +00:00
Nuno Cruces
cb62771a45 Examples. 2023-12-20 16:59:16 +00:00
Nuno Cruces
0bb1cd5e2e Rework error messages, see #45. 2023-12-20 16:10:50 +00:00
Danlock
7bbd4f1e3c Fix regex link typo 2023-12-19 16:01:58 +00:00
Nuno Cruces
ed4a3a894b Extension API tweaks. 2023-12-19 15:24:54 +00:00
Nuno Cruces
f1b00a9944 wasi-sdk-21. 2023-12-19 00:33:04 +00:00
Nuno Cruces
9281948f57 Extension API tweaks. 2023-12-19 00:13:51 +00:00
Nuno Cruces
b0b27439b5 Fix macOS osAllocate.
Mozilla is just wrong.
https://searchfox.org/mozilla-central/source/xpcom/glue/FileUtils.cpp
2023-12-17 05:19:27 +00:00
Nuno Cruces
c938577763 Update README.md 2023-12-15 11:05:53 +00:00
Nuno Cruces
ebbb969cd7 Tweaks. 2023-12-15 00:46:12 +00:00
Nuno Cruces
0171743e88 Blob IO extension. 2023-12-14 23:04:18 +00:00
Nuno Cruces
c68413bd53 Optimize interrupts. 2023-12-14 17:23:46 +00:00
Nuno Cruces
3f8b480ba0 Optimize declared types. 2023-12-14 17:23:46 +00:00
Nuno Cruces
9866067701 Improve function cache.
Assume interned strings.
2023-12-14 17:22:49 +00:00
Nuno Cruces
964a42c76d Improve function cache.
Implement a 4x larger, PLRU bit cache.
2023-12-14 11:32:43 +00:00
Nuno Cruces
0b093b7c0e More tests. 2023-12-12 16:55:17 +00:00
Nuno Cruces
32a824cb6c Tests. 2023-12-12 14:06:54 +00:00
Nuno Cruces
2e1c65147a BSD tests. 2023-12-12 12:03:16 +00:00
Nuno Cruces
86cc08e4d6 Fix BSD tests. 2023-12-12 02:48:44 +00:00
dependabot[bot]
05077b8845 Bump actions/setup-go from 4 to 5
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 4 to 5.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-12-12 01:04:25 +00:00
Nuno Cruces
6e8d5e5be6 More fileio. 2023-12-12 01:00:13 +00:00
Nuno Cruces
c99fbcea6f Towards fileio extension. 2023-12-11 14:48:15 +00:00
Nuno Cruces
831a34a4c4 Updated dependencies. 2023-12-07 14:00:08 +00:00
Nuno Cruces
7c820ede3c Driver time formatting. 2023-12-07 13:49:33 +00:00
Nuno Cruces
089a0c0670 Pivot virtual table. 2023-12-06 17:49:48 +00:00
Nuno Cruces
8b45cac16b Improved error handling. 2023-12-05 18:17:33 +00:00
Nuno Cruces
06d2ff6752 Optimize VFS find. 2023-12-05 14:11:20 +00:00
Nuno Cruces
987f0f13a2 Test CPUs. 2023-12-04 14:01:25 +00:00
Nuno Cruces
cd40213898 Reuse statement, API. 2023-12-04 13:46:48 +00:00
Nuno Cruces
8a0baedc10 Tests, fixes. 2023-12-02 12:17:18 +00:00
Nuno Cruces
c667a1f469 Declared type. 2023-12-02 12:17:18 +00:00
Nuno Cruces
9c562f5d8b Cache functions. 2023-12-02 12:17:18 +00:00
Nuno Cruces
d862f47d95 Deoptimize. 2023-12-02 12:17:18 +00:00
Nuno Cruces
a9e32fd3f0 Fix compiler crash. 2023-12-02 12:17:18 +00:00
Nuno Cruces
b262f5cd01 Statement virtual table. 2023-12-02 12:17:18 +00:00
Nuno Cruces
4160b9a4bb Simplify tails. 2023-11-30 18:18:27 +00:00
Nuno Cruces
dbaf2d99cd Unprotected values. 2023-11-30 00:29:41 +00:00
Nuno Cruces
3f05115cd7 Virtual table API. 2023-11-29 10:46:11 +00:00
Nuno Cruces
9bf14becaf Reentrant arenas. 2023-11-29 10:46:02 +00:00
Nuno Cruces
997e197f54 VFS tweaks. 2023-11-28 16:38:02 +00:00
Nuno Cruces
b81fe284b6 memdb WAL. 2023-11-28 11:40:04 +00:00
dependabot[bot]
269306c5c8 Bump golang.org/x/sys from 0.14.0 to 0.15.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.14.0 to 0.15.0.
- [Commits](https://github.com/golang/sys/compare/v0.14.0...v0.15.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-11-28 02:00:06 +00:00
Nuno Cruces
8a18243830 Reproducible builds. 2023-11-28 01:55:26 +00:00
Nuno Cruces
fcd6cc91d8 Skip BOM. 2023-11-27 23:35:43 +00:00
Nuno Cruces
c1838fc0bc Fix encoding issues. 2023-11-27 15:37:53 +00:00
137 changed files with 5040 additions and 1504 deletions

View File

@@ -9,3 +9,7 @@ updates:
directory: "/" # Location of package manifests
schedule:
interval: "daily"
- package-ecosystem: "github-actions" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "daily"

View File

@@ -12,18 +12,13 @@ jobs:
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: GOOS=freebsd go test -c ./...
- name: Test
uses: cross-platform-actions/action@v0.21.1
uses: cross-platform-actions/action@v0.22.0
with:
operating_system: freebsd
version: '13.2'
memory: 8G
sync_files: runner-to-vm
run: find . -name '*.test' -maxdepth 1 -exec {} -test.v \;
run: |
sudo pkg install -y go121
go121 test -v ./...

40
.github/workflows/cpu.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
name: CPUs
on:
workflow_dispatch:
jobs:
test-386:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v5
with:
go-version: stable
- name: Test
run: GOARCH=386 go test -v ./...
test-arm:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v5
with:
go-version: stable
- name: Install QEMU
uses: docker/setup-qemu-action@v3
- name: Test
run: GOARCH=arm64 go test -v ./...

View File

@@ -12,11 +12,11 @@ echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
echo windows ; GOOS=windows GOARCH=amd64 go build .
# echo aix ; GOOS=aix GOARCH=ppc64 go build .
echo aix ; GOOS=aix GOARCH=ppc64 go build .
echo js ; GOOS=js GOARCH=wasm go build .
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
echo darwin-nosys ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_nosys .
echo linux-nosys ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_nosys .
echo windows-nosys ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_nosys .
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .

View File

@@ -4,16 +4,14 @@ on:
workflow_dispatch:
jobs:
test:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: stable

View File

@@ -20,7 +20,7 @@ jobs:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: stable
@@ -58,7 +58,6 @@ jobs:
with:
chart: true
amend: true
reuse-go: true
if: |
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'

23
.github/workflows/repro.sh vendored Executable file
View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ "$OSTYPE" == "linux"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-21/wasi-sdk-21.0-linux.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz"
elif [[ "$OSTYPE" == "darwin"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-21/wasi-sdk-21.0-macos.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-macos.tar.gz"
elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-21/wasi-sdk-21.0.m-mingw.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-windows.tar.gz"
fi
# Download tools
mkdir -p tools
[ -d "tools/wasi-sdk"* ] || curl -#L "$WASI_SDK" | tar xzC tools &
[ -d "tools/binaryen-version"* ] || curl -#L "$BINARYEN" | tar xzC tools &
wait
sqlite3/download.sh # Download SQLite
embed/build.sh # Build WASM
git diff --exit-code # Check diffs

24
.github/workflows/repro.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Reproducible build
on:
workflow_dispatch:
jobs:
build:
strategy:
matrix:
os: [macos-latest, ubuntu-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v5
with:
go-version: stable
- name: Build
run: .github/workflows/repro.sh

View File

@@ -4,8 +4,14 @@
[![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report)
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
Go module `github.com/ncruces/go-sqlite3` is `cgo`-free [SQLite](https://sqlite.org/) wrapper.\
It provides a [`database/sql`](https://pkg.go.dev/database/sql) compatible driver,
as well as direct access to most of the [C SQLite API](https://sqlite.org/cintro.html).
It wraps a [WASM](https://webassembly.org/) build of SQLite, and uses [wazero](https://wazero.io/) as the runtime.\
Go, wazero and [`x/sys`](https://pkg.go.dev/golang.org/x/sys) are the _only_ runtime dependencies.
### Packages
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://sqlite.org/cintro.html)
@@ -20,18 +26,26 @@ and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Loadable extensions
### Extensions
- [`github.com/ncruces/go-sqlite3/ext/array`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blob)
- [`github.com/ncruces/go-sqlite3/ext/array`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/array)
provides the [`array`](https://sqlite.org/carray.html) table-valued function.
- [`github.com/ncruces/go-sqlite3/ext/blob`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blob)
- [`github.com/ncruces/go-sqlite3/ext/blobio`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blobio)
simplifies [incremental BLOB I/O](https://sqlite.org/c3ref/blob_open.html).
- [`github.com/ncruces/go-sqlite3/ext/csv`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/csv)
reads [comma-separated values](https://sqlite.org/csv.html).
- [`github.com/ncruces/go-sqlite3/ext/fileio`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/fileio)
reads, writes and lists files.
- [`github.com/ncruces/go-sqlite3/ext/hash`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/hash)
provides cryptographic hash functions.
- [`github.com/ncruces/go-sqlite3/ext/lines`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/lines)
reads files [line-by-line](https://github.com/asg017/sqlite-lines).
reads data [line-by-line](https://github.com/asg017/sqlite-lines).
- [`github.com/ncruces/go-sqlite3/ext/pivot`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/pivot)
creates [pivot tables](https://github.com/jakethaw/pivot_vtab).
- [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement)
creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab).
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
provides [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
provides [statistics](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html) functions.
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
provides [Unicode aware](https://sqlite.org/src/dir/ext/icu) functions.
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
@@ -41,14 +55,17 @@ and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
### Advanced features
- [x] [incremental BLOB I/O](https://sqlite.org/c3ref/blob_open.html)
- [x] [nested transactions](https://sqlite.org/lang_savepoint.html)
- [x] [custom functions](https://sqlite.org/c3ref/create_function.html)
- [x] [virtual tables](https://sqlite.org/vtab.html)
- [x] [custom VFSes](https://sqlite.org/vfs.html)
- [x] [online backup](https://sqlite.org/backup.html)
- [x] [JSON support](https://www.sqlite.org/json1.html)
- [x] [Unicode support](https://sqlite.org/src/dir/ext/icu)
- [incremental BLOB I/O](https://sqlite.org/c3ref/blob_open.html)
- [nested transactions](https://sqlite.org/lang_savepoint.html)
- [custom functions](https://sqlite.org/c3ref/create_function.html)
- [virtual tables](https://sqlite.org/vtab.html)
- [custom VFSes](https://sqlite.org/vfs.html)
- [online backup](https://sqlite.org/backup.html)
- [JSON support](https://sqlite.org/json1.html)
- [math functions](https://sqlite.org/lang_mathfunc.html)
- [full-text search](https://sqlite.org/fts5.html)
- [geospatial search](https://sqlite.org/geopoly.html)
- [and more…](embed/README.md)
### Caveats
@@ -93,12 +110,22 @@ To use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with `nolock=1` you must disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Testing
### Testing
This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report).
It also benefits greatly from [SQLite's](https://www.sqlite.org/testing.html) and
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing.
The pure Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS, Windows and FreeBSD.
Performance is tested by running
### Performance
Perfomance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
[competitive](https://github.com/cvilsmeier/go-sqlite-bench) with alternatives.
The WASM and VFS layers are also tested by running SQLite's
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Alternatives

View File

@@ -62,7 +62,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
}
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
defer c.arena.reset()
defer c.arena.mark()()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
@@ -71,12 +71,12 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
other = src
}
r := c.call(c.api.backupInit,
r := c.call("sqlite3_backup_init",
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
r = c.call("sqlite3_errcode", uint64(dst))
return nil, c.sqlite.error(r, dst)
}
@@ -97,7 +97,7 @@ func (b *Backup) Close() error {
return nil
}
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
r := b.c.call("sqlite3_backup_finish", uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r)
@@ -108,7 +108,7 @@ func (b *Backup) Close() error {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
r := b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage))
if r == _DONE {
return true, nil
}
@@ -120,7 +120,7 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
r := b.c.call("sqlite3_backup_remaining", uint64(b.handle))
return int(r)
}
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupPageCount, uint64(b.handle))
r := b.c.call("sqlite3_backup_pagecount", uint64(b.handle))
return int(r)
}

36
blob.go
View File

@@ -30,7 +30,7 @@ var _ io.ReadWriteSeeker = &Blob{}
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
tablePtr := c.arena.string(table)
@@ -41,7 +41,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1
}
r := c.call(c.api.blobOpen, uint64(c.handle),
r := c.call("sqlite3_blob_open", uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
@@ -51,7 +51,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
blob := Blob{c: c}
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle)))
blob.bytes = int64(c.call("sqlite3_blob_bytes", uint64(blob.handle)))
return &blob, nil
}
@@ -65,7 +65,7 @@ func (b *Blob) Close() error {
return nil
}
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
r := b.c.call("sqlite3_blob_close", uint64(b.handle))
b.handle = 0
return b.c.error(r)
@@ -92,10 +92,10 @@ func (b *Blob) Read(p []byte) (n int, err error) {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
defer b.c.arena.mark()()
ptr := b.c.arena.new(uint64(want))
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
@@ -124,11 +124,11 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
defer b.c.arena.mark()()
ptr := b.c.arena.new(uint64(want))
for want > 0 {
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
@@ -158,10 +158,10 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
//
// https://sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
defer b.c.arena.mark()()
ptr := b.c.arena.bytes(p)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(b.offset))
err = b.c.error(r)
if err != nil {
@@ -187,14 +187,14 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want = 1
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
defer b.c.arena.mark()()
ptr := b.c.arena.new(uint64(want))
for {
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := r.Read(mem[:want])
if m > 0 {
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
uint64(ptr), uint64(m), uint64(b.offset))
err := b.c.error(r)
if err != nil {
@@ -243,8 +243,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row)))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle)))
err := b.c.error(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row)))
b.bytes = int64(b.c.call("sqlite3_blob_bytes", uint64(b.handle)))
b.offset = 0
return err
}

57
config.go Normal file
View File

@@ -0,0 +1,57 @@
package sqlite3
import (
"context"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// Config makes configuration changes to a database connection.
// Only boolean configuration options are supported.
// Called with no arg reads the current configuration value,
// called with one arg sets and returns the new value.
//
// https://sqlite.org/c3ref/db_config.html
func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
defer c.arena.mark()()
argsPtr := c.arena.new(2 * ptrlen)
var flag int
switch {
case len(arg) == 0:
flag = -1
case arg[0]:
flag = 1
}
util.WriteUint32(c.mod, argsPtr+0*ptrlen, uint32(flag))
util.WriteUint32(c.mod, argsPtr+1*ptrlen, argsPtr)
r := c.call("sqlite3_db_config", uint64(c.handle),
uint64(op), uint64(argsPtr))
return util.ReadUint32(c.mod, argsPtr) != 0, c.error(r)
}
// ConfigLog sets up the error logging callback for the connection.
//
// https://www.sqlite.org/errlog.html
func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error {
var enable uint64
if cb != nil {
enable = 1
}
r := c.call("sqlite3_config_log_go", enable)
if err := c.error(r); err != nil {
return err
}
c.log = cb
return nil
}
func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil {
msg := util.ReadString(mod, zMsg, _MAX_LENGTH)
c.log(xErrorCode(iCode), msg)
}
}

79
conn.go
View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -21,6 +20,7 @@ type Conn struct {
interrupt context.Context
pending *Stmt
log func(code xErrorCode, msg string)
arena arena
handle uint32
@@ -56,8 +56,6 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
defer func() {
if conn == nil {
sqlite.close()
} else {
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
@@ -72,12 +70,12 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
}
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
defer c.arena.reset()
defer c.arena.mark()()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.sqlite.error(r, handle); err != nil {
@@ -92,13 +90,12 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteByte(';')
pragmas.WriteString(`;`)
}
}
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
@@ -107,12 +104,12 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
return 0, err
}
}
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
r := c.call("sqlite3_close_v2", uint64(handle))
if err := c.sqlite.error(r, handle); err != nil {
panic(err)
}
@@ -135,13 +132,12 @@ func (c *Conn) Close() error {
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
r := c.call("sqlite3_close", uint64(c.handle))
if err := c.error(r); err != nil {
return err
}
c.handle = 0
runtime.SetFinalizer(c, nil)
return c.close()
}
@@ -151,11 +147,11 @@ func (c *Conn) Close() error {
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.reset()
defer c.arena.mark()()
sqlPtr := c.arena.string(sql)
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r)
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r, sql)
}
// Prepare calls [Conn.PrepareFlags] with no flags.
@@ -173,23 +169,21 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_LENGTH {
return nil, "", TOOBIG
}
if emptyStatement(sql) {
return nil, "", nil
}
defer c.arena.reset()
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r := c.call(c.api.prepare, uint64(c.handle),
r := c.call("sqlite3_prepare_v3", uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
stmt = &Stmt{c: c}
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" {
tail = sql
}
if err := c.error(r, sql); err != nil {
return nil, "", err
@@ -197,14 +191,14 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if stmt.handle == 0 {
return nil, "", nil
}
return
return stmt, tail, nil
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call(c.api.autocommit, uint64(c.handle))
r := c.call("sqlite3_get_autocommit", uint64(c.handle))
return r != 0
}
@@ -213,7 +207,7 @@ func (c *Conn) GetAutocommit() bool {
//
// https://sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
r := c.call(c.api.lastRowid, uint64(c.handle))
r := c.call("sqlite3_last_insert_rowid", uint64(c.handle))
return int64(r)
}
@@ -223,7 +217,7 @@ func (c *Conn) LastInsertRowID() int64 {
//
// https://sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
r := c.call(c.api.changes, uint64(c.handle))
r := c.call("sqlite3_changes64", uint64(c.handle))
return int64(r)
}
@@ -247,27 +241,30 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
return ctx
}
// An uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
// A busy SQL statement prevents SQLite from ignoring an interrupt
// that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
} else {
c.pending.Reset()
c.pending, _, _ = c.Prepare(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
}
old = c.interrupt
c.interrupt = ctx
// Remove the handler if the context can't be canceled.
if ctx == nil || ctx.Done() == nil {
c.call(c.api.progressHandler, uint64(c.handle), 0)
return old
}
c.pending.Step()
c.call(c.api.progressHandler, uint64(c.handle), 100)
if old != nil && old.Done() != nil && (ctx == nil || ctx.Err() == nil) {
c.pending.Reset()
}
if ctx != nil && ctx.Done() != nil {
c.pending.Step()
}
return old
}
func (c *Conn) checkInterrupt() {
if c.interrupt != nil && c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", uint64(c.handle))
}
}
func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt != nil && c.interrupt.Err() != nil {
@@ -277,12 +274,6 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 {
return 0
}
func (c *Conn) checkInterrupt() {
if c.interrupt != nil && c.interrupt.Err() != nil {
c.call(c.api.interrupt, uint64(c.handle))
}
}
// Pragma executes a PRAGMA statement and returns any results.
//
// https://sqlite.org/pragma.html

View File

@@ -9,10 +9,11 @@ const (
_UTF8 = 1
_MAX_NAME = 512 // Used for short strings: names, error messages…
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
_MAX_ALLOCATION_SIZE = 0x7ffffeff
_MAX_FUNCTION_ARG = 100
ptrlen = 4
)
@@ -169,16 +170,63 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
// FunctionFlag is a flag that can be passed to
// [Conn.CreateFunction] and [Conn.CreateWindowFunction].
//
// https://sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
RESULT_SUBTYPE FunctionFlag = 0x001000000
)
// StmtStatus name counter values associated with the [Stmt.Status] method.
//
// https://sqlite.org/c3ref/c_stmtstatus_counter.html
type StmtStatus uint32
const (
STMTSTATUS_FULLSCAN_STEP StmtStatus = 1
STMTSTATUS_SORT StmtStatus = 2
STMTSTATUS_AUTOINDEX StmtStatus = 3
STMTSTATUS_VM_STEP StmtStatus = 4
STMTSTATUS_REPREPARE StmtStatus = 5
STMTSTATUS_RUN StmtStatus = 6
STMTSTATUS_FILTER_MISS StmtStatus = 7
STMTSTATUS_FILTER_HIT StmtStatus = 8
STMTSTATUS_MEMUSED StmtStatus = 99
)
// DBConfig are the available database connection configuration options.
//
// https://sqlite.org/c3ref/c_dbconfig_defensive.html
type DBConfig uint32
const (
// DBCONFIG_MAINDBNAME DBConfig = 1000
// DBCONFIG_LOOKASIDE DBConfig = 1001
DBCONFIG_ENABLE_FKEY DBConfig = 1002
DBCONFIG_ENABLE_TRIGGER DBConfig = 1003
DBCONFIG_ENABLE_FTS3_TOKENIZER DBConfig = 1004
DBCONFIG_ENABLE_LOAD_EXTENSION DBConfig = 1005
DBCONFIG_NO_CKPT_ON_CLOSE DBConfig = 1006
DBCONFIG_ENABLE_QPSG DBConfig = 1007
DBCONFIG_TRIGGER_EQP DBConfig = 1008
DBCONFIG_RESET_DATABASE DBConfig = 1009
DBCONFIG_DEFENSIVE DBConfig = 1010
DBCONFIG_WRITABLE_SCHEMA DBConfig = 1011
DBCONFIG_LEGACY_ALTER_TABLE DBConfig = 1012
DBCONFIG_DQS_DML DBConfig = 1013
DBCONFIG_DQS_DDL DBConfig = 1014
DBCONFIG_ENABLE_VIEW DBConfig = 1015
DBCONFIG_LEGACY_FILE_FORMAT DBConfig = 1016
DBCONFIG_TRUSTED_SCHEMA DBConfig = 1017
DBCONFIG_STMT_SCANSTATUS DBConfig = 1018
DBCONFIG_REVERSE_SCANORDER DBConfig = 1019
)
// Datatype is a fundamental datatype of SQLite.

View File

@@ -32,14 +32,14 @@ func (ctx Context) Conn() *Conn {
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(ctx.c.ctx, data)
ctx.c.call(ctx.c.api.setAuxData, uint64(ctx.handle), uint64(n), uint64(ptr))
ctx.c.call("sqlite3_set_auxdata_go", uint64(ctx.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) GetAuxData(n int) any {
ptr := uint32(ctx.c.call(ctx.c.api.getAuxData, uint64(ctx.handle), uint64(n)))
ptr := uint32(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n)))
return util.GetHandle(ctx.c.ctx, ptr)
}
@@ -67,7 +67,7 @@ func (ctx Context) ResultInt(value int) {
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultInt64(value int64) {
ctx.c.call(ctx.c.api.resultInteger,
ctx.c.call("sqlite3_result_int64",
uint64(ctx.handle), uint64(value))
}
@@ -75,7 +75,7 @@ func (ctx Context) ResultInt64(value int64) {
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultFloat(value float64) {
ctx.c.call(ctx.c.api.resultFloat,
ctx.c.call("sqlite3_result_double",
uint64(ctx.handle), math.Float64bits(value))
}
@@ -84,9 +84,9 @@ func (ctx Context) ResultFloat(value float64) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultText(value string) {
ptr := ctx.c.newString(value)
ctx.c.call(ctx.c.api.resultText,
ctx.c.call("sqlite3_result_text64",
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor), _UTF8)
uint64(ctx.c.freer), _UTF8)
}
// ResultRawText sets the text result of the function to a []byte.
@@ -94,9 +94,9 @@ func (ctx Context) ResultText(value string) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultRawText(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call(ctx.c.api.resultText,
ctx.c.call("sqlite3_result_text64",
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor), _UTF8)
uint64(ctx.c.freer), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
@@ -105,16 +105,16 @@ func (ctx Context) ResultRawText(value []byte) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultBlob(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call(ctx.c.api.resultBlob,
ctx.c.call("sqlite3_result_blob64",
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor))
uint64(ctx.c.freer))
}
// ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultZeroBlob(n int64) {
ctx.c.call(ctx.c.api.resultZeroBlob,
ctx.c.call("sqlite3_result_zeroblob64",
uint64(ctx.handle), uint64(n))
}
@@ -122,7 +122,7 @@ func (ctx Context) ResultZeroBlob(n int64) {
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultNull() {
ctx.c.call(ctx.c.api.resultNull,
ctx.c.call("sqlite3_result_null",
uint64(ctx.handle))
}
@@ -153,9 +153,9 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
buf := util.View(ctx.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
ctx.c.call(ctx.c.api.resultText,
ctx.c.call("sqlite3_result_text64",
uint64(ctx.handle), uint64(ptr), uint64(len(buf)),
uint64(ctx.c.api.destructor), _UTF8)
uint64(ctx.c.freer), _UTF8)
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
@@ -165,7 +165,7 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
ctx.c.call("sqlite3_result_pointer_go", uint64(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
@@ -175,6 +175,7 @@ func (ctx Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
ctx.ResultError(err)
return
}
ctx.ResultRawText(data)
}
@@ -183,10 +184,11 @@ func (ctx Context) ResultJSON(value any) {
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultValue(value Value) {
if value.sqlite != ctx.c.sqlite {
if value.c != ctx.c {
ctx.ResultError(MISUSE)
return
}
ctx.c.call(ctx.c.api.resultValue,
ctx.c.call("sqlite3_result_value",
uint64(ctx.handle), uint64(value.handle))
}
@@ -195,24 +197,33 @@ func (ctx Context) ResultValue(value Value) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
ctx.c.call(ctx.c.api.resultErrorMem, uint64(ctx.handle))
ctx.c.call("sqlite3_result_error_nomem", uint64(ctx.handle))
return
}
if errors.Is(err, TOOBIG) {
ctx.c.call(ctx.c.api.resultErrorBig, uint64(ctx.handle))
ctx.c.call("sqlite3_result_error_toobig", uint64(ctx.handle))
return
}
msg, code := errorCode(err, _OK)
if msg != "" {
ptr := ctx.c.newString(msg)
ctx.c.call(ctx.c.api.resultError,
defer ctx.c.arena.mark()()
ptr := ctx.c.arena.string(msg)
ctx.c.call("sqlite3_result_error",
uint64(ctx.handle), uint64(ptr), uint64(len(msg)))
ctx.c.free(ptr)
}
if code != _OK {
ctx.c.call(ctx.c.api.resultErrorCode,
ctx.c.call("sqlite3_result_error_code",
uint64(ctx.handle), uint64(code))
}
}
// VTabNoChange may return true if a column is being fetched as part
// of an update during which the column value will not change.
//
// https://www.sqlite.org/c3ref/vtab_nochange.html
func (ctx Context) VTabNoChange() bool {
r := ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle))
return r != 0
}

View File

@@ -12,6 +12,18 @@
//
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
//
// Possible values are: "deferred", "immediate", "exclusive".
// A [read-only] transaction is always "deferred", regardless of "_txlock".
//
// The time encoding/decoding format can be specified using "_timefmt":
//
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
//
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// "rfc3339" encodes and decodes RFC 3339 only.
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
@@ -23,7 +35,9 @@
//
// [URI]: https://sqlite.org/uri.html
// [PRAGMA]: https://sqlite.org/pragma.html
// [format]: https://sqlite.org/lang_datefunc.html#time_values
// [TRANSACTION]: https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
// [read-only]: https://pkg.go.dev/database/sql#TxOptions
package driver
import (
@@ -36,6 +50,7 @@ import (
"net/url"
"strings"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -43,7 +58,7 @@ import (
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
// go build -ldflags="-X github.com/ncruces/go-sqlite3/driver.driverName=sqlite"
var driverName = "sqlite3"
func init() {
@@ -81,23 +96,52 @@ func (sqlite) OpenConnector(name string) (driver.Connector, error) {
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
var txlock, timefmt string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
if err != nil {
return nil, err
}
c.txlock = query.Get("_txlock")
c.pragmas = len(query["_pragma"]) > 0
txlock = query.Get("_txlock")
timefmt = query.Get("_timefmt")
c.pragmas = query.Has("_pragma")
}
}
switch txlock {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", txlock)
}
switch timefmt {
case "":
c.tmRead = sqlite3.TimeFormatAuto
c.tmWrite = sqlite3.TimeFormatDefault
case "sqlite":
c.tmRead = sqlite3.TimeFormatAuto
c.tmWrite = sqlite3.TimeFormat3
case "rfc3339":
c.tmRead = sqlite3.TimeFormatDefault
c.tmWrite = sqlite3.TimeFormatDefault
default:
c.tmRead = sqlite3.TimeFormat(timefmt)
c.tmWrite = sqlite3.TimeFormat(timefmt)
}
return &c, nil
}
type connector struct {
init func(*sqlite3.Conn) error
name string
txlock string
txBegin string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
pragmas bool
}
@@ -106,7 +150,12 @@ func (n *connector) Driver() driver.Driver {
}
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
var c conn
c := &conn{
txBegin: n.txBegin,
tmRead: n.tmRead,
tmWrite: n.tmWrite,
}
c.Conn, err = sqlite3.Open(n.name)
if err != nil {
return nil, err
@@ -120,14 +169,6 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
switch n.txlock {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + n.txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
}
if !n.pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
if err != nil {
@@ -155,7 +196,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
return nil, err
}
}
return &c, nil
return c, nil
}
type conn struct {
@@ -163,6 +204,8 @@ type conn struct {
txBegin string
txCommit string
txRollback string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
readOnly byte
}
@@ -247,19 +290,10 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
return nil, err
}
if tail != "" {
// Check if the tail contains any SQL.
st, _, err := c.Conn.Prepare(tail)
if err != nil {
s.Close()
return nil, err
}
if st != nil {
s.Close()
st.Close()
return nil, util.TailErr
}
s.Close()
return nil, util.TailErr
}
return &stmt{s, c.Conn}, nil
return &stmt{Stmt: s, tmRead: c.tmRead, tmWrite: c.tmWrite}, nil
}
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
@@ -270,7 +304,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
if savept, ok := ctx.(*saveptCtx); ok {
// Called from driver.Savepoint.
savept.Savepoint = c.Savepoint()
savept.Savepoint = c.Conn.Savepoint()
return resultRowsAffected(0), nil
}
@@ -290,8 +324,9 @@ func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
}
type stmt struct {
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
*sqlite3.Stmt
tmWrite sqlite3.TimeFormat
tmRead sqlite3.TimeFormat
}
var (
@@ -301,10 +336,6 @@ var (
_ driver.NamedValueChecker = &stmt{}
)
func (s *stmt) Close() error {
return s.Stmt.Close()
}
func (s *stmt) NumInput() int {
n := s.Stmt.BindCount()
for i := 1; i <= n; i++ {
@@ -331,15 +362,15 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
return nil, err
}
old := s.Conn.SetInterrupt(ctx)
defer s.Conn.SetInterrupt(old)
old := s.Stmt.Conn().SetInterrupt(ctx)
defer s.Stmt.Conn().SetInterrupt(old)
err = s.Stmt.Exec()
if err != nil {
return nil, err
}
return newResult(s.Conn), nil
return newResult(s.Stmt.Conn()), nil
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -347,7 +378,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
if err != nil {
return nil, err
}
return &rows{ctx, s.Stmt, s.Conn}, nil
return &rows{ctx: ctx, stmt: s}, nil
}
func (s *stmt) setupBindings(args []driver.NamedValue) error {
@@ -386,11 +417,11 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case interface{ Pointer() any }:
err = s.Stmt.BindPointer(id, a.Pointer())
case interface{ JSON() any }:
err = s.Stmt.BindJSON(id, a.JSON())
err = s.Stmt.BindTime(id, a, s.tmWrite)
case util.JSON:
err = s.Stmt.BindJSON(id, a.Value)
case util.PointerUnwrap:
err = s.Stmt.BindPointer(id, util.UnwrapPointer(a))
case nil:
err = s.Stmt.BindNull(id)
default:
@@ -407,9 +438,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time,
interface{ Pointer() any },
interface{ JSON() any },
time.Time, sqlite3.ZeroBlob,
util.JSON, util.PointerUnwrap,
nil:
return nil
default:
@@ -449,27 +479,52 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
}
type rows struct {
ctx context.Context
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
ctx context.Context
*stmt
names []string
types []string
}
func (r *rows) Close() error {
r.Stmt.ClearBindings()
return r.Stmt.Reset()
}
func (r *rows) Columns() []string {
count := r.Stmt.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.Stmt.ColumnName(i)
if r.names == nil {
count := r.Stmt.ColumnCount()
r.names = make([]string, count)
for i := range r.names {
r.names[i] = r.Stmt.ColumnName(i)
}
}
return columns
return r.names
}
func (r *rows) declType(index int) string {
if r.types == nil {
count := r.Stmt.ColumnCount()
r.types = make([]string, count)
for i := range r.types {
r.types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
}
}
return r.types[index]
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
decltype := r.declType(index)
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
decltype = decltype[:i]
}
}
return strings.TrimSpace(decltype)
}
func (r *rows) Next(dest []driver.Value) error {
old := r.Conn.SetInterrupt(r.ctx)
defer r.Conn.SetInterrupt(old)
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)
if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {
@@ -478,22 +533,34 @@ func (r *rows) Next(dest []driver.Value) error {
return io.EOF
}
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
err := r.Stmt.Columns(data)
for i := range dest {
switch r.Stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.Stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.Stmt.ColumnFloat(i)
case sqlite3.BLOB:
dest[i] = r.Stmt.ColumnRawBlob(i)
case sqlite3.TEXT:
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
case sqlite3.NULL:
dest[i] = nil
default:
panic(util.AssertErr())
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
} else if s, ok := dest[i].(string); ok {
dest[i] = stringOrTime(s)
}
}
return r.Stmt.Err()
return err
}
func (r *rows) decodeTime(i int, v any) (_ time.Time, _ bool) {
if r.tmRead == sqlite3.TimeFormatDefault {
return
}
switch r.declType(i) {
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// maybe
default:
return
}
switch v.(type) {
case int64, float64, string:
// maybe
default:
return
}
t, err := r.tmRead.Decode(v)
return t, err == nil
}

View File

@@ -6,6 +6,7 @@ import (
"database/sql"
"errors"
"math"
"net/url"
"path/filepath"
"testing"
"time"
@@ -114,13 +115,7 @@ func Test_Open_txLock(t *testing.T) {
func Test_Open_txLock_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
_, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err == nil {
t.Fatal("want error")
}
@@ -186,12 +181,6 @@ func Test_Prepare(t *testing.T) {
}
defer db.Close()
stmt, err := db.Prepare(`SELECT 1; -- HERE`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
var serr *sqlite3.Error
_, err = db.Prepare(`SELECT`)
if err == nil {
@@ -207,18 +196,14 @@ func Test_Prepare(t *testing.T) {
t.Error("got message:", got)
}
_, err = db.Prepare(`SELECT 1; `)
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
_, err = db.Prepare(`SELECT 1; SELECT`)
if err == nil {
t.Error("want error")
}
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != sqlite3.ERROR {
t.Errorf("got %d, want sqlite3.ERROR", rc)
}
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
t.Error("got message:", got)
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
@@ -311,3 +296,39 @@ func Test_QueryRow_blob_null(t *testing.T) {
}
}
}
func Test_time(t *testing.T) {
t.Parallel()
for _, fmt := range []string{"auto", "sqlite", "rfc3339", time.ANSIC} {
t.Run(fmt, func(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?_timefmt="+url.QueryEscape(fmt))
if err != nil {
t.Fatal(err)
}
defer db.Close()
twosday := time.Date(2022, 2, 22, 22, 22, 22, 0, time.UTC)
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (at DATETIME)`)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(`INSERT INTO test VALUES (?)`, twosday)
if err != nil {
t.Fatal(err)
}
var got time.Time
err = db.QueryRow(`SELECT * FROM test`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if !got.Equal(twosday) {
t.Errorf("got: %v", got)
}
})
}
}

View File

@@ -9,23 +9,24 @@ import (
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func stringOrTime(text []byte) driver.Value {
func stringOrTime(text string) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return string(text)
return text
}
if len(text) > len(time.RFC3339Nano) {
return string(text)
return text
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return string(text)
return text
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && date.Format(time.RFC3339Nano) == string(text) {
var buf [len(time.RFC3339Nano)]byte
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
return date
}
return string(text)
return text
}

View File

@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := stringOrTime([]byte(str))
value := stringOrTime(str)
switch v := value.(type) {
case time.Time:
@@ -49,17 +49,17 @@ func Fuzz_stringOrTime_1(f *testing.F) {
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
f.Add(0, 999_999_999)
f.Add(0, 1_000_000_000)
f.Add(7956915742, 222_222_222) // twosday
f.Add(639095955742, 222_222_222) // twosday, year 22222AD
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
f.Add(int64(0), int64(0))
f.Add(int64(0), int64(1))
f.Add(int64(0), int64(-1))
f.Add(int64(0), int64(999_999_999))
f.Add(int64(0), int64(1_000_000_000))
f.Add(int64(7956915742), int64(222_222_222)) // twosday
f.Add(int64(639095955742), int64(222_222_222)) // twosday, year 22222AD
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
checkTime := func(t testing.TB, date time.Time) {
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
value := stringOrTime(date.Format(time.RFC3339Nano))
switch v := value.(type) {
case time.Time:
@@ -67,7 +67,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
if !v.Equal(date) {
t.Fatalf("did not round-trip: %v", date)
}
// Make with the same zone offset:
// With the same zone offset:
_, off1 := v.Zone()
_, off2 := date.Zone()
if off1 != off2 {
@@ -80,7 +80,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
}
}
f.Fuzz(func(t *testing.T, sec, nsec int) {
f.Fuzz(func(t *testing.T, sec, nsec int64) {
// Reduce the search space.
if 1e12 < sec || sec < -1e12 {
// Dates before 29000BC and after 33000AD; I think we're safe.
@@ -91,7 +91,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
return
}
unix := time.Unix(int64(sec), int64(nsec))
unix := time.Unix(sec, nsec)
checkTime(t, unix)
checkTime(t, unix.UTC())
checkTime(t, unix.In(time.FixedZone("", -8*3600)))

View File

@@ -1,6 +1,6 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.44.2 for use with
This folder includes an embeddable WASM build of SQLite 3.45.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
@@ -12,6 +12,7 @@ The following optional features are compiled in:
- [soundex](https://sqlite.org/lang_corefunc.html#soundex)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [ieee754](https://github.com/sqlite/sqlite/blob/master/ext/misc/ieee754.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
- [series](https://github.com/sqlite/sqlite/blob/master/ext/misc/series.c)
- [uint](https://github.com/sqlite/sqlite/blob/master/ext/misc/uint.c)

View File

@@ -5,9 +5,10 @@ cd -P -- "$(dirname -- "$0")"
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-21.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
"$WASI_SDK/clang" --target=wasm32-wasi -std=c17 -flto -g0 -O2 \
-Wall -Wextra -Wno-unused-parameter \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
@@ -23,7 +24,7 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
"$BINARYEN/wasm-opt" -g --strip --strip-producers -c -O3 \
sqlite3.tmp -o sqlite3.wasm \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \

View File

@@ -1,91 +1,105 @@
free
malloc
malloc_destructor
sqlite3_errcode
sqlite3_errstr
sqlite3_errmsg
sqlite3_error_offset
sqlite3_open_v2
sqlite3_close
sqlite3_close_v2
sqlite3_prepare_v3
sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_interrupt
sqlite3_progress_handler_go
sqlite3_clear_bindings
sqlite3_aggregate_context
sqlite3_anycollseq_init
sqlite3_backup_finish
sqlite3_backup_init
sqlite3_backup_pagecount
sqlite3_backup_remaining
sqlite3_backup_step
sqlite3_bind_blob64
sqlite3_bind_double
sqlite3_bind_int64
sqlite3_bind_null
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
sqlite3_bind_parameter_name
sqlite3_bind_null
sqlite3_bind_int64
sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_bind_pointer_go
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
sqlite3_column_int64
sqlite3_column_double
sqlite3_column_text
sqlite3_bind_text64
sqlite3_bind_value
sqlite3_bind_zeroblob64
sqlite3_blob_bytes
sqlite3_blob_close
sqlite3_blob_open
sqlite3_blob_read
sqlite3_blob_reopen
sqlite3_blob_write
sqlite3_changes64
sqlite3_clear_bindings
sqlite3_close
sqlite3_close_v2
sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_reopen
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_uri_parameter
sqlite3_uri_key
sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit
sqlite3_anycollseq_init
sqlite3_column_count
sqlite3_column_decltype
sqlite3_column_double
sqlite3_column_int64
sqlite3_column_name
sqlite3_column_text
sqlite3_column_type
sqlite3_column_value
sqlite3_columns_go
sqlite3_config_log_go
sqlite3_create_aggregate_function_go
sqlite3_create_collation_go
sqlite3_create_function_go
sqlite3_create_aggregate_function_go
sqlite3_create_module_go
sqlite3_create_window_function_go
sqlite3_aggregate_context
sqlite3_user_data
sqlite3_set_auxdata_go
sqlite3_db_config
sqlite3_declare_vtab
sqlite3_errcode
sqlite3_errmsg
sqlite3_error_offset
sqlite3_errstr
sqlite3_exec
sqlite3_finalize
sqlite3_get_autocommit
sqlite3_get_auxdata
sqlite3_value_type
sqlite3_value_int64
sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_value_pointer_go
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_interrupt
sqlite3_last_insert_rowid
sqlite3_open_v2
sqlite3_overload_function
sqlite3_prepare_v3
sqlite3_progress_handler_go
sqlite3_reset
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_pointer_go
sqlite3_result_value
sqlite3_result_double
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem
sqlite3_result_error_toobig
sqlite3_create_module_go
sqlite3_declare_vtab
sqlite3_vtab_config_go
sqlite3_result_int64
sqlite3_result_null
sqlite3_result_pointer_go
sqlite3_result_text64
sqlite3_result_value
sqlite3_result_zeroblob64
sqlite3_set_auxdata_go
sqlite3_step
sqlite3_stmt_busy
sqlite3_stmt_readonly
sqlite3_stmt_status
sqlite3_uri_key
sqlite3_uri_parameter
sqlite3_user_data
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_value_double
sqlite3_value_dup
sqlite3_value_free
sqlite3_value_int64
sqlite3_value_nochange
sqlite3_value_numeric_type
sqlite3_value_pointer_go
sqlite3_value_text
sqlite3_value_type
sqlite3_vtab_collation
sqlite3_vtab_config_go
sqlite3_vtab_distinct
sqlite3_vtab_in
sqlite3_vtab_in_first
sqlite3_vtab_in_next
sqlite3_vtab_rhs_value
sqlite3_vtab_nochange
sqlite3_vtab_on_conflict
sqlite3_vtab_on_conflict
sqlite3_vtab_rhs_value

Binary file not shown.

View File

@@ -44,8 +44,7 @@ func (e *Error) Error() string {
}
if e.msg != "" {
b.WriteByte(':')
b.WriteByte(' ')
b.WriteString(": ")
b.WriteString(e.msg)
}
@@ -139,12 +138,14 @@ func (e ExtendedErrorCode) Timeout() bool {
func errorCode(err error, def ErrorCode) (msg string, code uint32) {
switch code := err.(type) {
case ErrorCode:
return "", uint32(code)
case ExtendedErrorCode:
return "", uint32(code)
case nil:
return "", _OK
case ErrorCode:
return "", uint32(code)
case xErrorCode:
return "", uint32(code)
case *Error:
return code.msg, uint32(code.code)
}
var ecode ErrorCode

View File

@@ -135,7 +135,7 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
r := db.call("sqlite3_errstr", uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
got := ErrorCode(i).Error()
@@ -157,7 +157,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
r := db.call("sqlite3_errstr", uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
got := ExtendedErrorCode(i).Error()

View File

@@ -1,4 +1,6 @@
// Package array provides the array table-valued SQL function.
//
// https://sqlite.org/carray.html
package array
import (
@@ -6,17 +8,17 @@ import (
"reflect"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the array single-argument, table-valued SQL function.
// The argument must be an [sqlite3.Pointer] to a Go slice or array
// of ints, floats, bools, strings or blobs.
//
// https://sqlite.org/carray.html
// The argument must be bound to a Go slice or array of
// ints, floats, bools, strings or byte slices,
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule[array](db, "array", nil,
func(db *sqlite3.Conn, arg ...string) (array, error) {
err := db.DeclareVtab(`CREATE TABLE x(value, array HIDDEN)`)
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
return array{}, err
})
}
@@ -102,7 +104,7 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error {
ctx.ResultBlob(v.Bytes())
default:
return fmt.Errorf("array: unsupported element:%.0w %v", sqlite3.MISMATCH, v.Type())
return fmt.Errorf("array: unsupported element:%.0w %v", sqlite3.MISMATCH, util.ReflectType(v))
}
return nil
}
@@ -119,17 +121,16 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return nil
}
func indexable(v reflect.Value) (_ reflect.Value, err error) {
if v.Kind() == reflect.Slice {
func indexable(v reflect.Value) (reflect.Value, error) {
switch v.Kind() {
case reflect.Slice:
return v, nil
}
if v.Kind() == reflect.Array {
case reflect.Array:
return v, nil
}
if v.Kind() == reflect.Pointer {
case reflect.Pointer:
if v := v.Elem(); v.Kind() == reflect.Array {
return v, nil
}
}
return v, fmt.Errorf("array: unsupported argument:%.0w %v", sqlite3.MISMATCH, v.Type())
return v, fmt.Errorf("array: unsupported argument:%.0w %v", sqlite3.MISMATCH, util.ReflectType(v))
}

View File

@@ -13,7 +13,7 @@ import (
"github.com/ncruces/go-sqlite3/ext/array"
)
func Example() {
func Example_driver() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
@@ -51,7 +51,45 @@ func Example() {
// geopoly_within
}
func Example() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
array.Register(db)
stmt, _, err := db.Prepare(`
SELECT name
FROM pragma_function_list
WHERE name like 'geopoly%' AND narg IN array(?)`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
err = stmt.BindPointer(1, [...]int{2, 3, 4})
if err != nil {
log.Fatal(err)
}
for stmt.Step() {
fmt.Printf("%s\n", stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// geopoly_regular
// geopoly_overlap
// geopoly_contains_point
// geopoly_within
}
func Test_cursor_Column(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
@@ -90,3 +128,29 @@ func Test_cursor_Column(t *testing.T) {
log.Fatal(err)
}
}
func Test_array_errors(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
array.Register(db)
err = db.Exec(`SELECT * FROM array()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`SELECT * FROM array(?)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
}

View File

@@ -1,70 +0,0 @@
// Package blob provides an alternative interface to incremental BLOB I/O.
package blob
import (
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the blob_open SQL function:
//
// blob_open(schema, table, column, rowid, flags, callback, args...)
//
// The callback must be an [sqlite3.Pointer] to an [OpenCallback].
// Any optional args will be passed to the callback,
// along with the [sqlite3.Blob] handle.
//
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) {
db.CreateFunction("blob_open", -1,
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
}
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(util.ErrorString("wrong number of arguments to function blob_open()"))
return
}
row := arg[3].Int64()
var err error
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
if ok {
err = blob.Reopen(row)
if errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table, column or write changed).
ok = false
}
}
if !ok {
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
write := arg[4].Bool()
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
}
if err != nil {
ctx.ResultError(err)
return
}
fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return
}
// This ensures the blob is closed if db, table, column or write change.
ctx.SetAuxData(0, blob)
ctx.SetAuxData(1, blob)
ctx.SetAuxData(2, blob)
ctx.SetAuxData(4, blob)
}
// OpenCallback is the type for the blob_open callback.
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error

View File

@@ -1,61 +0,0 @@
package blob_test
import (
"io"
"log"
"os"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/blob"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blob.Register(conn)
return nil
})
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
const message = "Hello BLOB!"
// Create the BLOB.
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}
// Write the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.WriteString(blob, message)
return err
}))
if err != nil {
log.Fatal(err)
}
// Read the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.Copy(os.Stdout, blob)
return err
}))
if err != nil {
log.Fatal(err)
}
// Output:
// Hello BLOB!
}

139
ext/blobio/blob.go Normal file
View File

@@ -0,0 +1,139 @@
// Package blobio provides an SQL interface to incremental BLOB I/O.
package blobio
import (
"errors"
"io"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the SQL functions:
//
// readblob(schema, table, column, rowid, offset, n)
//
// Reads n bytes of a blob, starting at offset.
//
// writeblob(schema, table, column, rowid, offset, data)
//
// Writes data into a blob, at the given offset.
//
// openblob(schema, table, column, rowid, write, callback, args...)
//
// Opens blobs for reading or writing.
// The callback is invoked for each open blob,
// and must be bound to an [OpenCallback],
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
// The optional args will be passed to the callback,
// along with the [sqlite3.Blob] handle.
//
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) {
db.CreateFunction("readblob", 6, sqlite3.DIRECTONLY, readblob)
db.CreateFunction("writeblob", 6, sqlite3.DIRECTONLY, writeblob)
db.CreateFunction("openblob", -1, sqlite3.DIRECTONLY, openblob)
}
// OpenCallback is the type for the openblob callback.
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
func readblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
blob, err := getAuxBlob(ctx, arg, false)
if err != nil {
ctx.ResultError(err)
return
}
_, err = blob.Seek(arg[4].Int64(), io.SeekStart)
if err != nil {
ctx.ResultError(err)
return
}
n := arg[5].Int64()
if n <= 0 {
return
}
buf := make([]byte, n)
_, err = io.ReadFull(blob, buf)
if err != nil {
ctx.ResultError(err)
return
}
ctx.ResultBlob(buf)
setAuxBlob(ctx, blob, false)
}
func writeblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
blob, err := getAuxBlob(ctx, arg, true)
if err != nil {
ctx.ResultError(err)
return
}
_, err = blob.Seek(arg[4].Int64(), io.SeekStart)
if err != nil {
ctx.ResultError(err)
return
}
_, err = blob.Write(arg[5].RawBlob())
if err != nil {
ctx.ResultError(err)
return
}
setAuxBlob(ctx, blob, false)
}
func openblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(util.ErrorString("openblob: wrong number of arguments"))
return
}
blob, err := getAuxBlob(ctx, arg, arg[4].Bool())
if err != nil {
ctx.ResultError(err)
return
}
fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return
}
setAuxBlob(ctx, blob, true)
}
func getAuxBlob(ctx sqlite3.Context, arg []sqlite3.Value, write bool) (*sqlite3.Blob, error) {
row := arg[3].Int64()
if blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob); ok {
if err := blob.Reopen(row); errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table, column or write changed).
} else {
return blob, err
}
}
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
return ctx.Conn().OpenBlob(db, table, column, row, write)
}
func setAuxBlob(ctx sqlite3.Context, blob *sqlite3.Blob, writer bool) {
// This ensures the blob is closed if db, table, column or write change.
ctx.SetAuxData(0, blob) // db
ctx.SetAuxData(1, blob) // table
ctx.SetAuxData(2, blob) // column
if writer {
ctx.SetAuxData(4, blob) // write
}
}

184
ext/blobio/blob_test.go Normal file
View File

@@ -0,0 +1,184 @@
package blobio_test
import (
"io"
"log"
"os"
"reflect"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/array"
"github.com/ncruces/go-sqlite3/ext/blobio"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blobio.Register(conn)
return nil
})
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
const message = "Hello BLOB!"
// Create the BLOB.
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}
// Write the BLOB.
_, err = db.Exec(`SELECT writeblob('main', 'test', 'col', last_insert_rowid(), 0, ?)`, message)
if err != nil {
log.Fatal(err)
}
// Read the BLOB.
_, err = db.Exec(`SELECT openblob('main', 'test', 'col', rowid, false, ?) FROM test`,
sqlite3.Pointer[blobio.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.Copy(os.Stdout, blob)
return err
}))
if err != nil {
log.Fatal(err)
}
// Output:
// Hello BLOB!
}
func Test_readblob(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
blobio.Register(db)
array.Register(db)
err = db.Exec(`SELECT readblob()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`
CREATE TABLE IF NOT EXISTS test1 (col);
CREATE TABLE IF NOT EXISTS test2 (col);
INSERT INTO test1 VALUES (x'cafe');
INSERT INTO test2 VALUES (x'babe');
`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT readblob('main', value, 'col', 1, 1, 1) FROM array(?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = stmt.BindPointer(1, []string{"test1", "test2"})
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
got := stmt.ColumnText(0)
if got != "\xfe" {
t.Errorf("got %q", got)
}
}
if stmt.Step() {
got := stmt.ColumnText(0)
if got != "\xbe" {
t.Errorf("got %q", got)
}
}
err = stmt.Err()
if err != nil {
t.Fatal(err)
}
}
func Test_openblob(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
blobio.Register(db)
array.Register(db)
err = db.Exec(`SELECT openblob()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`
CREATE TABLE IF NOT EXISTS test1 (col);
CREATE TABLE IF NOT EXISTS test2 (col);
INSERT INTO test1 VALUES (x'cafe');
INSERT INTO test2 VALUES (x'babe');
`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT openblob('main', value, 'col', 1, false, ?) FROM array(?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
var got []string
err = stmt.BindPointer(1, blobio.OpenCallback(func(b *sqlite3.Blob, _ ...sqlite3.Value) error {
d, err := io.ReadAll(b)
if err != nil {
return err
}
got = append(got, string(d))
return nil
}))
if err != nil {
t.Fatal(err)
}
err = stmt.BindPointer(2, []string{"test1", "test2"})
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
want := []string{"\xca\xfe", "\xba\xbe"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

36
ext/csv/arg.go Normal file
View File

@@ -0,0 +1,36 @@
package csv
import (
"fmt"
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
func uintArg(key, val string) (int, error) {
i, err := strconv.ParseUint(val, 10, 15)
if err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return int(i), nil
}
func boolArg(key, val string) (bool, error) {
if val == "" {
return true, nil
}
b, ok := util.ParseBool(val)
if ok {
return b, nil
}
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
func runeArg(key, val string) (rune, error) {
r, _, tail, err := strconv.UnquoteChar(vtabutil.Unquote(val), 0)
if tail != "" || err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return r, nil
}

View File

@@ -1,8 +1,14 @@
package csv
import "testing"
import (
"testing"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
func Test_uintArg(t *testing.T) {
t.Parallel()
func Test_uintParam(t *testing.T) {
tests := []struct {
arg string
key string
@@ -18,22 +24,22 @@ func Test_uintParam(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
key, val := vtabutil.NamedArg(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
}
got, err := uintParam(key, val)
got, err := uintArg(key, val)
if (err != nil) != tt.err {
t.Fatalf("uintParam() error = %v, want err %v", err, tt.err)
t.Fatalf("uintArg() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("uintParam() = %v, want %v", got, tt.val)
t.Errorf("uintArg() = %v, want %v", got, tt.val)
}
})
}
}
func Test_boolParam(t *testing.T) {
func Test_boolArg(t *testing.T) {
tests := []struct {
arg string
key string
@@ -54,22 +60,22 @@ func Test_boolParam(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
key, val := vtabutil.NamedArg(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
}
got, err := boolParam(key, val)
got, err := boolArg(key, val)
if (err != nil) != tt.err {
t.Fatalf("boolParam() error = %v, want err %v", err, tt.err)
t.Fatalf("boolArg() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("boolParam() = %v, want %v", got, tt.val)
t.Errorf("boolArg() = %v, want %v", got, tt.val)
}
})
}
}
func Test_runeParam(t *testing.T) {
func Test_runeArg(t *testing.T) {
tests := []struct {
arg string
key string
@@ -86,16 +92,16 @@ func Test_runeParam(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
key, val := vtabutil.NamedArg(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
}
got, err := runeParam(key, val)
got, err := runeArg(key, val)
if (err != nil) != tt.err {
t.Fatalf("runeParam() error = %v, want err %v", err, tt.err)
t.Fatalf("runeArg() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("runeParam() = %v, want %v", got, tt.val)
t.Errorf("runeArg() = %v, want %v", got, tt.val)
}
})
}

View File

@@ -7,28 +7,28 @@
package csv
import (
"bufio"
"encoding/csv"
"fmt"
"io"
"math"
"os"
"io/fs"
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
// Register registers the CSV virtual table.
// If a filename is specified, `os.Open` is used to read it from disk.
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterOpen(db, func(name string) (io.ReaderAt, error) {
return os.Open(name)
})
RegisterFS(db, osutil.FS{})
}
// RegisterOpen registers the CSV virtual table.
// If a filename is specified, open is used to open the file.
func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) {
declare := func(db *sqlite3.Conn, arg ...string) (_ *table, err error) {
// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
data string
@@ -40,24 +40,24 @@ func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error))
done = map[string]struct{}{}
)
for _, arg := range arg[3:] {
key, val := getParam(arg)
for _, arg := range arg {
key, val := vtabutil.NamedArg(arg)
if _, ok := done[key]; ok {
return nil, fmt.Errorf("csv: more than one %q parameter", key)
}
switch key {
case "filename":
filename = unquoteParam(val)
filename = vtabutil.Unquote(val)
case "data":
data = unquoteParam(val)
data = vtabutil.Unquote(val)
case "schema":
schema = unquoteParam(val)
schema = vtabutil.Unquote(val)
case "header":
header, err = boolParam(key, val)
header, err = boolArg(key, val)
case "columns":
columns, err = uintParam(key, val)
columns, err = uintArg(key, val)
case "comma":
comma, err = runeParam(key, val)
comma, err = runeArg(key, val)
default:
return nil, fmt.Errorf("csv: unknown %q parameter", key)
}
@@ -71,41 +71,35 @@ func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error))
return nil, fmt.Errorf(`csv: must specify either "filename" or "data" but not both`)
}
var r io.ReaderAt
if filename != "" {
r, err = open(filename)
} else {
r = strings.NewReader(data)
}
if err != nil {
return nil, err
}
table := &table{
r: r,
fsys: fsys,
name: filename,
data: data,
comma: comma,
header: header,
}
defer func() {
if err != nil {
table.Close()
}
}()
if schema == "" && (header || columns < 0) {
csv := table.newReader()
row, err := csv.Read()
if err != nil {
return nil, err
if schema == "" {
var row []string
if header || columns < 0 {
csv, c, err := table.newReader()
defer c.Close()
if err != nil {
return nil, err
}
row, err = csv.Read()
if err != nil {
return nil, err
}
}
schema = getSchema(header, columns, row)
}
err = db.DeclareVtab(schema)
err = db.DeclareVTab(schema)
if err != nil {
return nil, err
}
err = db.VtabConfig(sqlite3.VTAB_DIRECTONLY)
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
if err != nil {
return nil, err
}
@@ -116,20 +110,13 @@ func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error))
}
type table struct {
r io.ReaderAt
fsys fs.FS
name string
data string
comma rune
header bool
}
func (t *table) Close() error {
if c, ok := t.r.(io.Closer); ok {
err := c.Close()
t.r = nil
return err
}
return nil
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.EstimatedCost = 1e6
return nil
@@ -143,29 +130,76 @@ func (t *table) Rename(new string) error {
return nil
}
func (t *table) Integrity(schema, table string, flags int) (err error) {
if flags&1 == 0 {
_, err = t.newReader().ReadAll()
func (t *table) Integrity(schema, table string, flags int) error {
if flags&1 != 0 {
return nil
}
csv, c, err := t.newReader()
if err != nil {
return err
}
defer c.Close()
_, err = csv.ReadAll()
return err
}
func (t *table) newReader() (*csv.Reader, io.Closer, error) {
var r io.Reader
var c io.Closer
if t.name != "" {
f, err := t.fsys.Open(t.name)
if err != nil {
return nil, f, err
}
buf := bufio.NewReader(f)
bom, err := buf.Peek(3)
if err != nil {
return nil, f, err
}
if string(bom) == "\xEF\xBB\xBF" {
buf.Discard(3)
}
r = buf
c = f
} else {
r = strings.NewReader(t.data)
c = io.NopCloser(r)
}
csv := csv.NewReader(r)
csv.ReuseRecord = true
csv.Comma = t.comma
return csv, c, nil
}
type cursor struct {
table *table
closer io.Closer
csv *csv.Reader
row []string
rowID int64
}
func (c *cursor) Close() (err error) {
if c.closer != nil {
err = c.closer.Close()
c.closer = nil
}
return err
}
func (t *table) newReader() *csv.Reader {
csv := csv.NewReader(io.NewSectionReader(t.r, 0, math.MaxInt64))
csv.ReuseRecord = true
csv.Comma = t.comma
return csv
}
type cursor struct {
table *table
rowID int64
row []string
csv *csv.Reader
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.csv = c.table.newReader()
err := c.Close()
if err != nil {
return err
}
c.csv, c.closer, err = c.table.newReader()
if err != nil {
return err
}
if c.table.header {
c.Next() // skip header
}

View File

@@ -51,6 +51,8 @@ func Example() {
}
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -82,8 +84,8 @@ Robert "Griesemer" "gri"`
if !stmt.Step() {
t.Fatal("no rows")
}
if got := stmt.ColumnText(1); got != "Pike" {
t.Errorf("got %q want Pike", got)
if got := stmt.ColumnText(0); got != "Rob" {
t.Errorf("got %q want Rob", got)
}
if stmt.Step() {
t.Fatal("more rows")
@@ -96,16 +98,23 @@ Robert "Griesemer" "gri"`
err = db.Exec(`PRAGMA integrity_check`)
if err != nil {
t.Fatal(err)
t.Error(err)
}
err = db.Exec(`PRAGMA quick_check`)
if err != nil {
t.Error(err)
}
err = db.Exec(`DROP TABLE temp.csv`)
if err != nil {
log.Fatal(err)
t.Error(err)
}
}
func TestRegister_errors(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)

View File

@@ -1,65 +0,0 @@
package csv
import (
"fmt"
"strconv"
"strings"
)
func getParam(arg string) (key, val string) {
key, val, _ = strings.Cut(arg, "=")
key = strings.TrimSpace(key)
val = strings.TrimSpace(val)
return
}
func uintParam(key, val string) (int, error) {
i, err := strconv.ParseUint(val, 10, 15)
if err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return int(i), nil
}
func boolParam(key, val string) (bool, error) {
if val == "" || val == "1" ||
strings.EqualFold(val, "true") ||
strings.EqualFold(val, "yes") ||
strings.EqualFold(val, "on") {
return true, nil
}
if val == "0" ||
strings.EqualFold(val, "false") ||
strings.EqualFold(val, "no") ||
strings.EqualFold(val, "off") {
return false, nil
}
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
func runeParam(key, val string) (rune, error) {
r, _, tail, err := strconv.UnquoteChar(unquoteParam(val), 0)
if tail != "" || err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return r, nil
}
func unquoteParam(val string) string {
if len(val) < 2 {
return val
}
if val[0] != val[len(val)-1] {
return val
}
var old, new string
switch val[0] {
default:
return val
case '"':
old, new = `""`, `"`
case '\'':
old, new = `''`, `'`
}
return strings.ReplaceAll(val[1:len(val)-1], old, new)
}

View File

@@ -8,9 +8,9 @@ import (
)
func getSchema(header bool, columns int, row []string) string {
var sep = ""
var sep string
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
str.WriteString("CREATE TABLE x(")
if 0 <= columns && columns < len(row) {
row = row[:columns]
@@ -20,15 +20,17 @@ func getSchema(header bool, columns int, row []string) string {
if header && f != "" {
str.WriteString(sqlite3.QuoteIdentifier(f))
} else {
str.WriteByte('c')
str.WriteString("c")
str.WriteString(strconv.Itoa(i + 1))
}
str.WriteString(" TEXT")
sep = ","
}
for i := len(row); i < columns; i++ {
str.WriteString(sep)
str.WriteByte('c')
str.WriteString("c")
str.WriteString(strconv.Itoa(i + 1))
str.WriteString(" TEXT")
sep = ","
}
str.WriteByte(')')

View File

@@ -3,16 +3,20 @@ package csv
import "testing"
func Test_getSchema(t *testing.T) {
t.Parallel()
tests := []struct {
header bool
columns int
row []string
want string
}{
{true, 2, nil, `CREATE TABLE x(c1,c2)`},
{false, 2, nil, `CREATE TABLE x(c1,c2)`},
{true, 3, []string{"abc", ""}, `CREATE TABLE x("abc",c2,c3)`},
{true, 1, []string{"abc", "def"}, `CREATE TABLE x("abc")`},
{true, 2, nil, `CREATE TABLE x(c1 TEXT,c2 TEXT)`},
{false, 2, nil, `CREATE TABLE x(c1 TEXT,c2 TEXT)`},
{false, -1, []string{"abc", ""}, `CREATE TABLE x(c1 TEXT,c2 TEXT)`},
{true, 3, []string{"abc", ""}, `CREATE TABLE x("abc" TEXT,c2 TEXT,c3 TEXT)`},
{true, -1, []string{"abc", "def"}, `CREATE TABLE x("abc" TEXT,"def" TEXT)`},
{true, 1, []string{"abc", "def"}, `CREATE TABLE x("abc" TEXT)`},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {

View File

@@ -1,4 +1,4 @@
Date,USD,JPY,BGN,CYP,CZK,DKK,EEK,GBP,HUF,LTL,LVL,MTL,PLN,ROL,RON,SEK,SIT,SKK,CHF,ISK,NOK,HRK,RUB,TRL,TRY,AUD,BRL,CAD,CNY,HKD,IDR,ILS,INR,KRW,MXN,MYR,NZD,PHP,SGD,THB,ZAR,
Date,USD,JPY,BGN,CYP,CZK,DKK,EEK,GBP,HUF,LTL,LVL,MTL,PLN,ROL,RON,SEK,SIT,SKK,CHF,ISK,NOK,HRK,RUB,TRL,TRY,AUD,BRL,CAD,CNY,HKD,IDR,ILS,INR,KRW,MXN,MYR,NZD,PHP,SGD,THB,ZAR,
2022-12-30,1.0666,140.66,1.9558,N/A,24.116,7.4365,N/A,0.88693,400.87,N/A,N/A,N/A,4.6808,N/A,4.9495,11.1218,N/A,N/A,0.9847,151.5,10.5138,7.5365,N/A,N/A,19.9649,1.5693,5.6386,1.444,7.3582,8.3163,16519.82,3.7554,88.171,1344.09,20.856,4.6984,1.6798,59.32,1.43,36.835,18.0986,
2022-12-29,1.0649,142.24,1.9558,N/A,24.191,7.4365,N/A,0.88549,399.6,N/A,N/A,N/A,4.6855,N/A,4.9493,11.158,N/A,N/A,0.984,152.5,10.55,7.5365,N/A,N/A,19.934,1.5859,5.5351,1.4475,7.4151,8.2994,16680.38,3.7575,88.2295,1350.18,20.651,4.7106,1.6887,59.367,1.436,36.877,18.1967,
2022-12-28,1.064,142.21,1.9558,N/A,24.252,7.4365,N/A,0.88058,403.3,N/A,N/A,N/A,4.7008,N/A,4.946,11.1038,N/A,N/A,0.9863,151.9,10.4495,7.5365,N/A,N/A,19.9144,1.566,5.6109,1.4361,7.4224,8.2931,16765.93,3.7526,88.0943,1348.59,20.6856,4.7055,1.6772,59.613,1.4323,36.953,18.289,
1 Date USD JPY BGN CYP CZK DKK EEK GBP HUF LTL LVL MTL PLN ROL RON SEK SIT SKK CHF ISK NOK HRK RUB TRL TRY AUD BRL CAD CNY HKD IDR ILS INR KRW MXN MYR NZD PHP SGD THB ZAR
2 2022-12-30 1.0666 140.66 1.9558 N/A 24.116 7.4365 N/A 0.88693 400.87 N/A N/A N/A 4.6808 N/A 4.9495 11.1218 N/A N/A 0.9847 151.5 10.5138 7.5365 N/A N/A 19.9649 1.5693 5.6386 1.444 7.3582 8.3163 16519.82 3.7554 88.171 1344.09 20.856 4.6984 1.6798 59.32 1.43 36.835 18.0986
3 2022-12-29 1.0649 142.24 1.9558 N/A 24.191 7.4365 N/A 0.88549 399.6 N/A N/A N/A 4.6855 N/A 4.9493 11.158 N/A N/A 0.984 152.5 10.55 7.5365 N/A N/A 19.934 1.5859 5.5351 1.4475 7.4151 8.2994 16680.38 3.7575 88.2295 1350.18 20.651 4.7106 1.6887 59.367 1.436 36.877 18.1967
4 2022-12-28 1.064 142.21 1.9558 N/A 24.252 7.4365 N/A 0.88058 403.3 N/A N/A N/A 4.7008 N/A 4.946 11.1038 N/A N/A 0.9863 151.9 10.4495 7.5365 N/A N/A 19.9144 1.566 5.6109 1.4361 7.4224 8.2931 16765.93 3.7526 88.0943 1348.59 20.6856 4.7055 1.6772 59.613 1.4323 36.953 18.289

59
ext/fileio/fileio.go Normal file
View File

@@ -0,0 +1,59 @@
// Package fileio provides SQL functions to read, write and list files.
//
// https://sqlite.org/src/doc/tip/ext/misc/fileio.c
package fileio
import (
"errors"
"fmt"
"io/fs"
"os"
"github.com/ncruces/go-sqlite3"
)
// Register registers SQL functions readfile, writefile, lsmode,
// and the table-valued function fsdir.
func Register(db *sqlite3.Conn) {
RegisterFS(db, nil)
}
// Register registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
db.CreateFunction("lsmode", 1, 0, lsmode)
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
if fsys == nil {
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
})
}
func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
}
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) {
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
var err error
var data []byte
if fsys != nil {
data, err = fs.ReadFile(fsys, arg[0].Text())
} else {
data, err = os.ReadFile(arg[0].Text())
}
switch {
case err == nil:
ctx.ResultBlob(data)
case !errors.Is(err, fs.ErrNotExist):
ctx.ResultError(fmt.Errorf("readfile: %w", err))
}
}
}

80
ext/fileio/fileio_test.go Normal file
View File

@@ -0,0 +1,80 @@
package fileio_test
import (
"bytes"
"database/sql"
"io/fs"
"os"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/fileio"
)
func Test_lsmode(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.Register(c)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
d, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
s, err := os.Stat(d)
if err != nil {
t.Fatal(err)
}
var mode string
err = db.QueryRow(`SELECT lsmode(?)`, s.Mode()).Scan(&mode)
if err != nil {
t.Fatal(err)
}
if len(mode) != 10 || mode[0] != 'd' {
t.Errorf("got %s", mode)
} else {
t.Logf("got %s", mode)
}
}
func Test_readfile(t *testing.T) {
t.Parallel()
for _, fsys := range []fs.FS{nil, os.DirFS(".")} {
t.Run("", func(t *testing.T) {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.RegisterFS(c, fsys)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
rows, err := db.Query(`SELECT readfile('fileio_test.go')`)
if err != nil {
t.Fatal(err)
}
if rows.Next() {
var data sql.RawBytes
rows.Scan(&data)
if !bytes.HasPrefix(data, []byte("package fileio_test")) {
t.Errorf("got %s", data[:min(64, len(data))])
}
}
})
}
}

186
ext/fileio/fsdir.go Normal file
View File

@@ -0,0 +1,186 @@
package fileio
import (
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"github.com/ncruces/go-sqlite3"
)
type fsdir struct{ fsys fs.FS }
func (d fsdir) BestIndex(idx *sqlite3.IndexInfo) error {
var root, base bool
for i, cst := range idx.Constraint {
switch cst.Column {
case 4: // root
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 1,
}
root = true
case 5: // base
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 2,
}
base = true
}
}
if !root {
return sqlite3.CONSTRAINT
}
if base {
idx.EstimatedCost = 10
} else {
idx.EstimatedCost = 100
}
return nil
}
func (d fsdir) Open() (sqlite3.VTabCursor, error) {
return &cursor{fsdir: d}, nil
}
type cursor struct {
fsdir
curr entry
next chan entry
done chan struct{}
base string
rowID int64
eof bool
}
type entry struct {
fs.DirEntry
err error
path string
}
func (c *cursor) Close() error {
if c.done != nil {
close(c.done)
s := <-c.next
c.done = nil
c.next = nil
return s.err
}
return nil
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
if err := c.Close(); err != nil {
return err
}
root := arg[0].Text()
if len(arg) > 1 {
base := arg[1].Text()
if c.fsys != nil {
root = path.Join(base, root)
base = path.Clean(base) + "/"
} else {
root = filepath.Join(base, root)
base = filepath.Clean(base) + string(filepath.Separator)
}
c.base = base
}
c.rowID = 0
c.eof = false
c.next = make(chan entry)
c.done = make(chan struct{})
go c.WalkDir(root)
return c.Next()
}
func (c *cursor) Next() error {
curr, ok := <-c.next
c.curr = curr
c.eof = !ok
c.rowID++
return c.curr.err
}
func (c *cursor) EOF() bool {
return c.eof
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, n int) error {
switch n {
case 0: // name
name := strings.TrimPrefix(c.curr.path, c.base)
ctx.ResultText(name)
case 1: // mode
i, err := c.curr.Info()
if err != nil {
return err
}
ctx.ResultInt64(int64(i.Mode()))
case 2: // mtime
i, err := c.curr.Info()
if err != nil {
return err
}
ctx.ResultTime(i.ModTime(), sqlite3.TimeFormatUnixFrac)
case 3: // data
switch typ := c.curr.Type(); {
case typ.IsRegular():
var data []byte
var err error
if c.fsys != nil {
data, err = fs.ReadFile(c.fsys, c.curr.path)
} else {
data, err = os.ReadFile(c.curr.path)
}
if err != nil {
return err
}
ctx.ResultBlob(data)
case typ&fs.ModeSymlink != 0 && c.fsys == nil:
t, err := os.Readlink(c.curr.path)
if err != nil {
return err
}
ctx.ResultText(t)
}
}
return nil
}
func (c *cursor) WalkDir(path string) {
defer close(c.next)
if c.fsys != nil {
fs.WalkDir(c.fsys, path, c.WalkDirFunc)
} else {
filepath.WalkDir(path, c.WalkDirFunc)
}
}
func (c *cursor) WalkDirFunc(path string, d fs.DirEntry, err error) error {
select {
case <-c.done:
return fs.SkipAll
case c.next <- entry{d, err, path}:
return nil
}
}

78
ext/fileio/fsdir_test.go Normal file
View File

@@ -0,0 +1,78 @@
package fileio_test
import (
"bytes"
"database/sql"
"io/fs"
"os"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/fileio"
)
func Test_fsdir(t *testing.T) {
t.Parallel()
for _, fsys := range []fs.FS{nil, os.DirFS(".")} {
t.Run("", func(t *testing.T) {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.RegisterFS(c, fsys)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
rows, err := db.Query(`SELECT * FROM fsdir('.', '.') LIMIT 4`)
if err != nil {
t.Fatal(err)
}
for rows.Next() {
var name string
var mode fs.FileMode
var mtime time.Time
var data sql.RawBytes
err := rows.Scan(&name, &mode, sqlite3.TimeFormatUnixFrac.Scanner(&mtime), &data)
if err != nil {
t.Fatal(err)
}
if mode.Perm() == 0 {
t.Errorf("got: %v", mode)
}
if mtime.Before(time.Unix(0, 0)) {
t.Errorf("got: %v", mtime)
}
if name == "fsdir_test.go" {
if !bytes.HasPrefix(data, []byte("package fileio_test")) {
t.Errorf("got: %s", data[:min(64, len(data))])
}
}
}
})
}
}
func Test_fsdir_errors(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
fileio.Register(db)
err = db.Exec(`SELECT name FROM fsdir()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
}

97
ext/fileio/write.go Normal file
View File

@@ -0,0 +1,97 @@
package fileio
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/fsutil"
)
func writefile(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 2 || len(arg) > 4 {
ctx.ResultError(util.ErrorString("writefile: wrong number of arguments"))
return
}
file := arg[0].Text()
var mode fs.FileMode
if len(arg) > 2 {
mode = fsutil.FileModeFromValue(arg[2])
}
n, err := createFileAndDir(file, mode, arg[1])
if err != nil {
if len(arg) > 2 {
ctx.ResultError(fmt.Errorf("writefile: %w", err))
}
return
}
if mode&fs.ModeSymlink == 0 {
if len(arg) > 2 {
err := os.Chmod(file, mode.Perm())
if err != nil {
ctx.ResultError(fmt.Errorf("writefile: %w", err))
return
}
}
if len(arg) > 3 {
mtime := arg[3].Time(sqlite3.TimeFormatUnixFrac)
err := os.Chtimes(file, time.Time{}, mtime)
if err != nil {
ctx.ResultError(fmt.Errorf("writefile: %w", err))
return
}
}
}
if mode.IsRegular() {
ctx.ResultInt(n)
}
}
func createFileAndDir(path string, mode fs.FileMode, data sqlite3.Value) (int, error) {
n, err := createFile(path, mode, data)
if errors.Is(err, fs.ErrNotExist) {
if err := os.MkdirAll(filepath.Dir(path), 0777); err == nil {
return createFile(path, mode, data)
}
}
return n, err
}
func createFile(path string, mode fs.FileMode, data sqlite3.Value) (int, error) {
if mode.IsRegular() {
blob := data.RawBlob()
return len(blob), os.WriteFile(path, blob, fixPerm(mode, 0666))
}
if mode.IsDir() {
err := os.Mkdir(path, fixPerm(mode, 0777))
if errors.Is(err, fs.ErrExist) {
s, err := os.Lstat(path)
if err == nil && s.IsDir() {
return 0, nil
}
}
return 0, err
}
if mode&fs.ModeSymlink != 0 {
return 0, os.Symlink(data.Text(), path)
}
return 0, fmt.Errorf("invalid mode: %v", mode)
}
func fixPerm(mode fs.FileMode, def fs.FileMode) fs.FileMode {
if mode.Perm() == 0 {
return def
}
return mode.Perm()
}

92
ext/fileio/write_test.go Normal file
View File

@@ -0,0 +1,92 @@
package fileio
import (
"database/sql"
"io/fs"
"path/filepath"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
func Test_writefile(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
dir := t.TempDir()
link := filepath.Join(dir, "link")
file := filepath.Join(dir, "test.txt")
nest := filepath.Join(dir, "tmp", "test.txt")
sock := filepath.Join(dir, "sock")
twosday := time.Date(2022, 2, 22, 22, 22, 22, 0, time.UTC)
_, err = db.Exec(`SELECT writefile(?, 'Hello world!')`, file)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(`SELECT writefile(?, ?, ?)`, link, "test.txt", fs.ModeSymlink)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(`SELECT writefile(?, ?, ?, ?)`, dir, nil, 0040700, twosday.Unix())
if err != nil {
t.Fatal(err)
}
rows, err := db.Query(`SELECT * FROM fsdir('.', ?)`, dir)
if err != nil {
t.Fatal(err)
}
for rows.Next() {
var name string
var mode fs.FileMode
var mtime time.Time
var data sql.NullString
err := rows.Scan(&name, &mode, &mtime, &data)
if err != nil {
t.Fatal(err)
}
if mode.IsDir() && !mtime.Equal(twosday) {
t.Errorf("got: %v", mtime)
}
if mode.IsRegular() && data.String != "Hello world!" {
t.Errorf("got: %v", data)
}
if mode&fs.ModeSymlink != 0 && data.String != "test.txt" {
t.Errorf("got: %v", data)
}
}
_, err = db.Exec(`SELECT writefile(?, 'Hello world!')`, nest)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(`SELECT writefile(?, ?, ?)`, sock, nil, fs.ModeSocket)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
_, err = db.Exec(`SELECT writefile()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
}

30
ext/hash/blake2.go Normal file
View File

@@ -0,0 +1,30 @@
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func blake2sFunc(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.BLAKE2s_256)
}
func blake2bFunc(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 512
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 256:
hashFunc(ctx, arg[0], crypto.BLAKE2b_256)
case 384:
hashFunc(ctx, arg[0], crypto.BLAKE2b_384)
case 512:
hashFunc(ctx, arg[0], crypto.BLAKE2b_512)
default:
ctx.ResultError(util.ErrorString("blake2b: size must be 256, 384, 512"))
}
}

97
ext/hash/hash.go Normal file
View File

@@ -0,0 +1,97 @@
// Package hash provides cryptographic hash functions.
//
// Provided functions:
// - md4(data)
// - md5(data)
// - sha1(data)
// - sha3(data, size) (default size 256)
// - sha224(data)
// - sha256(data, size) (default size 256)
// - sha384(data)
// - sha512(data, size) (default size 512)
// - blake2s(data)
// - blake2b(data, size) (default size 512)
// - ripemd160(data)
//
// Each SQL function will only be registered if the corresponding
// [crypto.Hash] function is available.
// To ensure a specific hash function is available,
// import the implementing package.
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
)
// Register registers cryptographic hash functions for a database connection.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
if crypto.MD4.Available() {
db.CreateFunction("md4", 1, flags, md4Func)
}
if crypto.MD5.Available() {
db.CreateFunction("md5", 1, flags, md5Func)
}
if crypto.SHA1.Available() {
db.CreateFunction("sha1", 1, flags, sha1Func)
}
if crypto.SHA3_512.Available() {
db.CreateFunction("sha3", 1, flags, sha3Func)
db.CreateFunction("sha3", 2, flags, sha3Func)
}
if crypto.SHA256.Available() {
db.CreateFunction("sha224", 1, flags, sha224Func)
db.CreateFunction("sha256", 1, flags, sha256Func)
db.CreateFunction("sha256", 2, flags, sha256Func)
}
if crypto.SHA512.Available() {
db.CreateFunction("sha384", 1, flags, sha384Func)
db.CreateFunction("sha512", 1, flags, sha512Func)
db.CreateFunction("sha512", 2, flags, sha512Func)
}
if crypto.BLAKE2s_256.Available() {
db.CreateFunction("blake2s", 1, flags, blake2sFunc)
}
if crypto.BLAKE2b_512.Available() {
db.CreateFunction("blake2b", 1, flags, blake2bFunc)
db.CreateFunction("blake2b", 2, flags, blake2bFunc)
}
if crypto.RIPEMD160.Available() {
db.CreateFunction("ripemd160", 1, flags, ripemd160Func)
}
}
func md4Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.MD4)
}
func md5Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.MD5)
}
func sha1Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.SHA1)
}
func ripemd160Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.RIPEMD160)
}
func hashFunc(ctx sqlite3.Context, arg sqlite3.Value, fn crypto.Hash) {
var data []byte
switch arg.Type() {
case sqlite3.NULL:
return
case sqlite3.BLOB:
data = arg.RawBlob()
default:
data = arg.RawText()
}
h := fn.New()
h.Write(data)
ctx.ResultBlob(h.Sum(nil))
}

98
ext/hash/hash_test.go Normal file
View File

@@ -0,0 +1,98 @@
package hash
import (
_ "crypto/md5"
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "golang.org/x/crypto/blake2b"
_ "golang.org/x/crypto/blake2s"
_ "golang.org/x/crypto/md4"
_ "golang.org/x/crypto/ripemd160"
_ "golang.org/x/crypto/sha3"
)
func TestRegister(t *testing.T) {
t.Parallel()
tests := []struct {
name string
hash string
}{
{"md4(NULL)", ""},
{"md4(X'')", "31D6CFE0D16AE931B73C59D7E0C089C0"},
{"md4('The quick brown fox jumps over the lazy dog')", "1BEE69A46BA811185C194762ABAEAE90"},
{"md5('')", "D41D8CD98F00B204E9800998ECF8427E"},
{"sha1('')", "DA39A3EE5E6B4B0D3255BFEF95601890AFD80709"},
{"ripemd160('')", "9C1185A5C5E9FC54612808977EE8F548B2258D31"},
{"sha224('')", "D14A028C2A3A2BC9476102BB288234C415A2B01F828EA62AC5B3E42F"},
{"sha256('')", "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855"},
{"sha256('', 224)", "D14A028C2A3A2BC9476102BB288234C415A2B01F828EA62AC5B3E42F"},
{"sha384('')", "38B060A751AC96384CD9327EB1B1E36A21FDB71114BE07434C0CC7BF63F6E1DA274EDEBFE76F65FBD51AD2F14898B95B"},
{"sha512('')", "CF83E1357EEFB8BDF1542850D66D8007D620E4050B5715DC83F4A921D36CE9CE47D0D13C5D85F2B0FF8318D2877EEC2F63B931BD47417A81A538327AF927DA3E"},
{"sha512('', 224)", "6ED0DD02806FA89E25DE060C19D3AC86CABB87D6A0DDD05C333B84F4"},
{"sha512('', 256)", "C672B8D1EF56ED28AB87C3622C5114069BDD3AD7B8F9737498D0C01ECEF0967A"},
{"sha512('', 384)", "38B060A751AC96384CD9327EB1B1E36A21FDB71114BE07434C0CC7BF63F6E1DA274EDEBFE76F65FBD51AD2F14898B95B"},
{"sha3('')", "A7FFC6F8BF1ED76651C14756A061D662F580FF4DE43B49FA82D80A4B80F8434A"},
{"sha3('', 224)", "6B4E03423667DBB73B6E15454F0EB1ABD4597F9A1B078E3F5B5A6BC7"},
{"sha3('', 384)", "0C63A75B845E4F7D01107D852E4C2485C51A50AAAA94FC61995E71BBEE983A2AC3713831264ADB47FB6BD1E058D5F004"},
{"sha3('', 512)", "A69F73CCA23A9AC5C8B567DC185A756E97C982164FE25859E0D1DCC1475C80A615B2123AF1F5F94C11E3E9402C3AC558F500199D95B6D3E301758586281DCD26"},
{"blake2s('')", "69217A3079908094E11121D042354A7C1F55B6482CA1A51E1B250DFD1ED0EEF9"},
{"blake2b('')", "786A02F742015903C6C6FD852552D272912F4740E15847618A86E217F71F5419D25E1031AFEE585313896444934EB04B903A685B1448B755D56F701AFE9BE2CE"},
{"blake2b('', 384)", "B32811423377F52D7862286EE1A72EE540524380FDA1724A6F25D7978C6FD3244A6CAF0498812673C5E05EF583825100"},
{"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"},
}
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var hash string
err = db.QueryRow(`SELECT hex(` + tt.name + `)`).Scan(&hash)
if err != nil {
t.Fatal(err)
}
if hash != tt.hash {
t.Errorf("got %s, want %s", hash, tt.hash)
}
})
}
_, err = db.Exec(`SELECT sha256('', 255)`)
if err == nil {
t.Error("want error")
}
_, err = db.Exec(`SELECT sha512('', 255)`)
if err == nil {
t.Error("want error")
}
_, err = db.Exec(`SELECT sha3('', 255)`)
if err == nil {
t.Error("want error")
}
_, err = db.Exec(`SELECT blake2b('', 255)`)
if err == nil {
t.Error("want error")
}
}

53
ext/hash/sha2.go Normal file
View File

@@ -0,0 +1,53 @@
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func sha224Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.SHA224)
}
func sha384Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.SHA384)
}
func sha256Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 256
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 224:
hashFunc(ctx, arg[0], crypto.SHA224)
case 256:
hashFunc(ctx, arg[0], crypto.SHA256)
default:
ctx.ResultError(util.ErrorString("sha256: size must be 224, 256"))
}
}
func sha512Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 512
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 224:
hashFunc(ctx, arg[0], crypto.SHA512_224)
case 256:
hashFunc(ctx, arg[0], crypto.SHA512_256)
case 384:
hashFunc(ctx, arg[0], crypto.SHA384)
case 512:
hashFunc(ctx, arg[0], crypto.SHA512)
default:
ctx.ResultError(util.ErrorString("sha512: size must be 224, 256, 384, 512"))
}
}

28
ext/hash/sha3.go Normal file
View File

@@ -0,0 +1,28 @@
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func sha3Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 256
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 224:
hashFunc(ctx, arg[0], crypto.SHA3_224)
case 256:
hashFunc(ctx, arg[0], crypto.SHA3_256)
case 384:
hashFunc(ctx, arg[0], crypto.SHA3_384)
case 512:
hashFunc(ctx, arg[0], crypto.SHA3_512)
default:
ctx.ResultError(util.ErrorString("sha3: size must be 224, 256, 384, 512"))
}
}

View File

@@ -1,4 +1,13 @@
// Package lines provides a virtual table to read large files line-by-line.
// Package lines provides a virtual table to read data line-by-line.
//
// It is particularly useful for line-oriented datasets,
// like [ndjson] or [JSON Lines],
// when paired with SQLite's JSON support.
//
// https://github.com/asg017/sqlite-lines
//
// [ndjson]: https://ndjson.org/
// [JSON Lines]: https://jsonlines.org/
package lines
import (
@@ -6,31 +15,42 @@ import (
"bytes"
"fmt"
"io"
"math"
"os"
"io/fs"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/osutil"
)
// Register registers the lines and lines_read virtual tables.
// The lines virtual table reads from a database blob or text.
// The lines_read virtual table reads from a file or an [io.ReaderAt].
// Register registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
sqlite3.CreateModule[lines](db, "lines", nil,
func(db *sqlite3.Conn, arg ...string) (lines, error) {
err := db.DeclareVtab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VtabConfig(sqlite3.VTAB_INNOCUOUS)
return false, err
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
return lines{}, err
})
sqlite3.CreateModule[lines](db, "lines_read", nil,
func(db *sqlite3.Conn, arg ...string) (lines, error) {
err := db.DeclareVtab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VtabConfig(sqlite3.VTAB_DIRECTONLY)
return true, err
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return lines{fsys}, err
})
}
type lines bool
type lines struct {
fsys fs.FS
}
func (l lines) BestIndex(idx *sqlite3.IndexInfo) error {
for i, cst := range idx.Constraint {
@@ -48,74 +68,126 @@ func (l lines) BestIndex(idx *sqlite3.IndexInfo) error {
}
func (l lines) Open() (sqlite3.VTabCursor, error) {
return &cursor{reader: bool(l)}, nil
if l.fsys != nil {
return &reader{fsys: l.fsys}, nil
} else {
return &buffer{}, nil
}
}
type cursor struct {
reader bool
scanner *bufio.Scanner
closer io.Closer
rowID int64
eof bool
}
func (c *cursor) Close() (err error) {
if c.closer != nil {
err = c.closer.Close()
c.closer = nil
}
return err
line []byte
rowID int64
eof bool
}
func (c *cursor) EOF() bool {
return c.eof
}
func (c *cursor) Next() error {
c.rowID++
c.eof = !c.scanner.Scan()
return c.scanner.Err()
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, n int) error {
if n == 0 {
ctx.ResultRawText(c.scanner.Bytes())
ctx.ResultRawText(c.line)
}
return nil
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
type reader struct {
fsys fs.FS
reader *bufio.Reader
closer io.Closer
cursor
}
func (c *reader) Close() (err error) {
if c.closer != nil {
err = c.closer.Close()
c.closer = nil
}
return err
}
func (c *reader) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
if err := c.Close(); err != nil {
return err
}
var r io.Reader
data := arg[0]
if c.reader {
if data.Type() == sqlite3.NULL {
if p, ok := data.Pointer().(io.ReaderAt); ok {
r = io.NewSectionReader(p, 0, math.MaxInt64)
}
} else {
f, err := os.Open(data.Text())
if err != nil {
return err
}
c.closer = f
r = f
typ := arg[0].Type()
switch typ {
case sqlite3.NULL:
if p, ok := arg[0].Pointer().(io.Reader); ok {
r = p
}
} else if data.Type() != sqlite3.NULL {
r = bytes.NewReader(data.RawBlob())
case sqlite3.TEXT:
f, err := c.fsys.Open(arg[0].Text())
if err != nil {
return err
}
r = f
}
if r == nil {
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ)
}
if r == nil {
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, data.Type())
}
c.scanner = bufio.NewScanner(r)
c.reader = bufio.NewReader(r)
c.closer, _ = r.(io.Closer)
c.rowID = 0
return c.Next()
}
func (c *reader) Next() (err error) {
c.line = c.line[:0]
for more := true; more; {
var line []byte
line, more, err = c.reader.ReadLine()
c.line = append(c.line, line...)
}
if err == io.EOF {
c.eof = true
err = nil
}
c.rowID++
return err
}
type buffer struct {
data []byte
cursor
}
func (c *buffer) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
typ := arg[0].Type()
switch typ {
case sqlite3.TEXT:
c.data = arg[0].RawText()
case sqlite3.BLOB:
c.data = arg[0].RawBlob()
default:
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ)
}
c.rowID = 0
return c.Next()
}
func (c *buffer) Next() error {
i := bytes.IndexByte(c.data, '\n')
j := i + 1
switch {
case i < 0:
i = len(c.data)
j = i
case i > 0 && c.data[i-1] == '\r':
i--
}
c.eof = len(c.data) == 0
c.line = c.data[:i]
c.data = c.data[j:]
c.rowID++
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"net/http"
"os"
"strings"
"testing"
@@ -25,12 +26,11 @@ func Example() {
}
defer db.Close()
// https://storage.googleapis.com/quickdraw_dataset/full/simplified/calendar.ndjson
f, err := os.Open("calendar.ndjson")
res, err := http.Get("https://storage.googleapis.com/quickdraw_dataset/full/simplified/calendar.ndjson")
if err != nil {
log.Fatal(err)
}
defer f.Close()
defer res.Body.Close()
rows, err := db.Query(`
SELECT
@@ -40,7 +40,7 @@ func Example() {
GROUP BY 1
ORDER BY 2 DESC
LIMIT 5`,
sqlite3.Pointer(f))
sqlite3.Pointer(res.Body))
if err != nil {
log.Fatal(err)
}
@@ -58,7 +58,7 @@ func Example() {
if err := rows.Err(); err != nil {
log.Fatal(err)
}
// Sample output:
// Output:
// US: 141001
// GB: 22560
// CA: 11759
@@ -67,6 +67,8 @@ func Example() {
}
func Test_lines(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
@@ -76,7 +78,7 @@ func Test_lines(t *testing.T) {
}
defer db.Close()
const data = "line 1\nline 2\nline 3"
const data = "line 1\nline 2\r\nline 3\n"
rows, err := db.Query(`SELECT rowid, line FROM lines(?)`, data)
if err != nil {
@@ -91,10 +93,15 @@ func Test_lines(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if want := fmt.Sprintf("line %d", id); line != want {
t.Errorf("got %q, want %q", line, want)
}
}
}
func Test_lines_error(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
@@ -120,6 +127,8 @@ func Test_lines_error(t *testing.T) {
}
func Test_lines_read(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
@@ -129,7 +138,7 @@ func Test_lines_read(t *testing.T) {
}
defer db.Close()
const data = "line 1\nline 2\nline 3"
const data = "line 1\nline 2\r\nline 3\n"
rows, err := db.Query(`SELECT rowid, line FROM lines_read(?)`,
sqlite3.Pointer(strings.NewReader(data)))
@@ -145,10 +154,15 @@ func Test_lines_read(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if want := fmt.Sprintf("line %d", id); line != want {
t.Errorf("got %q, want %q", line, want)
}
}
}
func Test_lines_test(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil

34
ext/pivot/op_test.go Normal file
View File

@@ -0,0 +1,34 @@
package pivot
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func Test_operator(t *testing.T) {
tests := []struct {
op sqlite3.IndexConstraintOp
want string
}{
{sqlite3.INDEX_CONSTRAINT_EQ, "="},
{sqlite3.INDEX_CONSTRAINT_LT, "<"},
{sqlite3.INDEX_CONSTRAINT_GT, ">"},
{sqlite3.INDEX_CONSTRAINT_LE, "<="},
{sqlite3.INDEX_CONSTRAINT_GE, ">="},
{sqlite3.INDEX_CONSTRAINT_NE, "<>"},
{sqlite3.INDEX_CONSTRAINT_IS, "IS"},
{sqlite3.INDEX_CONSTRAINT_ISNOT, "IS NOT"},
{sqlite3.INDEX_CONSTRAINT_REGEXP, "REGEXP"},
{sqlite3.INDEX_CONSTRAINT_MATCH, "MATCH"},
{sqlite3.INDEX_CONSTRAINT_GLOB, "GLOB"},
{sqlite3.INDEX_CONSTRAINT_LIKE, "LIKE"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
if got := operator(tt.op); got != tt.want {
t.Errorf("operator() = %v, want %v", got, tt.want)
}
})
}
}

274
ext/pivot/pivot.go Normal file
View File

@@ -0,0 +1,274 @@
// Package pivot implements a pivot virtual table.
//
// https://github.com/jakethaw/pivot_vtab
package pivot
import (
"errors"
"fmt"
"strings"
"github.com/ncruces/go-sqlite3"
)
// Register registers the pivot virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "pivot", declare, declare)
}
type table struct {
db *sqlite3.Conn
scan string
cell string
keys []string
cols []*sqlite3.Value
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
if len(arg) != 3 {
return nil, fmt.Errorf("pivot: wrong number of arguments")
}
table := &table{db: db}
defer func() {
if err != nil {
table.Close()
}
}()
var sep string
var create strings.Builder
create.WriteString("CREATE TABLE x(")
// Row key query.
table.scan = "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(table.scan)
if err != nil {
return nil, err
}
defer stmt.Close()
table.keys = make([]string, stmt.ColumnCount())
for i := range table.keys {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
table.keys[i] = name
create.WriteString(sep)
create.WriteString(name)
sep = ","
}
stmt.Close()
// Column definition query.
stmt, _, err = db.Prepare("SELECT * FROM\n" + arg[1])
if err != nil {
return nil, err
}
if stmt.ColumnCount() != 2 {
return nil, fmt.Errorf("pivot: column definition query expects 2 result columns")
}
for stmt.Step() {
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
table.cols = append(table.cols, stmt.ColumnValue(0).Dup())
create.WriteString(",")
create.WriteString(name)
}
stmt.Close()
// Pivot cell query.
table.cell = "SELECT * FROM\n" + arg[2]
stmt, _, err = db.Prepare(table.cell)
if err != nil {
return nil, err
}
if stmt.ColumnCount() != 1 {
return nil, fmt.Errorf("pivot: cell query expects 1 result columns")
}
if stmt.BindCount() != len(table.keys)+1 {
return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1)
}
create.WriteByte(')')
err = db.DeclareVTab(create.String())
if err != nil {
return nil, err
}
return table, nil
}
func (t *table) Close() error {
for i := range t.cols {
t.cols[i].Close()
}
return nil
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
var idxStr strings.Builder
idxStr.WriteString(t.scan)
argvIndex := 1
sep := " WHERE "
for i, cst := range idx.Constraint {
if !cst.Usable || !(0 <= cst.Column && cst.Column < len(t.keys)) {
continue
}
op := operator(cst.Op)
if op == "" {
continue
}
idxStr.WriteString(sep)
idxStr.WriteString(t.keys[cst.Column])
idxStr.WriteString(" ")
idxStr.WriteString(op)
idxStr.WriteString(" ?")
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
}
sep = " AND "
argvIndex++
}
sep = " ORDER BY "
idx.OrderByConsumed = true
for _, ord := range idx.OrderBy {
if !(0 <= ord.Column && ord.Column < len(t.keys)) {
idx.OrderByConsumed = false
continue
}
idxStr.WriteString(sep)
idxStr.WriteString(t.keys[ord.Column])
idxStr.WriteString(" COLLATE ")
idxStr.WriteString(idx.Collation(ord.Column))
if ord.Desc {
idxStr.WriteString(" DESC")
}
sep = ","
}
idx.EstimatedCost = 1e9 / float64(argvIndex)
idx.IdxStr = idxStr.String()
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
return &cursor{table: t}, nil
}
func (t *table) Rename(new string) error {
return nil
}
type cursor struct {
table *table
scan *sqlite3.Stmt
cell *sqlite3.Stmt
rowID int64
}
func (c *cursor) Close() error {
return errors.Join(c.scan.Close(), c.cell.Close())
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
err := c.scan.Close()
if err != nil {
return err
}
c.scan, _, err = c.table.db.Prepare(idxStr)
if err != nil {
return err
}
for i, arg := range arg {
err := c.scan.BindValue(i+1, arg)
if err != nil {
return err
}
}
if c.cell == nil {
c.cell, _, err = c.table.db.Prepare(c.table.cell)
if err != nil {
return err
}
}
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
if c.scan.Step() {
count := c.scan.ColumnCount()
for i := 0; i < count; i++ {
err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
if err != nil {
return err
}
}
c.rowID++
}
return c.scan.Err()
}
func (c *cursor) EOF() bool {
return !c.scan.Busy()
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
count := c.scan.ColumnCount()
if col < count {
ctx.ResultValue(c.scan.ColumnValue(col))
return nil
}
err := c.cell.BindValue(count+1, *c.table.cols[col-count])
if err != nil {
return err
}
if c.cell.Step() {
ctx.ResultValue(c.cell.ColumnValue(0))
}
return c.cell.Reset()
}
func operator(op sqlite3.IndexConstraintOp) string {
switch op {
case sqlite3.INDEX_CONSTRAINT_EQ:
return "="
case sqlite3.INDEX_CONSTRAINT_LT:
return "<"
case sqlite3.INDEX_CONSTRAINT_GT:
return ">"
case sqlite3.INDEX_CONSTRAINT_LE:
return "<="
case sqlite3.INDEX_CONSTRAINT_GE:
return ">="
case sqlite3.INDEX_CONSTRAINT_NE:
return "<>"
case sqlite3.INDEX_CONSTRAINT_MATCH:
return "MATCH"
case sqlite3.INDEX_CONSTRAINT_LIKE:
return "LIKE"
case sqlite3.INDEX_CONSTRAINT_GLOB:
return "GLOB"
case sqlite3.INDEX_CONSTRAINT_REGEXP:
return "REGEXP"
case sqlite3.INDEX_CONSTRAINT_IS, sqlite3.INDEX_CONSTRAINT_ISNULL:
return "IS"
case sqlite3.INDEX_CONSTRAINT_ISNOT, sqlite3.INDEX_CONSTRAINT_ISNOTNULL:
return "IS NOT"
default:
return ""
}
}

219
ext/pivot/pivot_test.go Normal file
View File

@@ -0,0 +1,219 @@
package pivot_test
import (
"fmt"
"log"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/pivot"
)
// https://antonz.org/sqlite-pivot-table/
func Example() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE sales(product TEXT, year INT, income DECIMAL);
INSERT INTO sales(product, year, income) VALUES
('alpha', 2020, 100),
('alpha', 2021, 120),
('alpha', 2022, 130),
('alpha', 2023, 140),
('beta', 2020, 10),
('beta', 2021, 20),
('beta', 2022, 40),
('beta', 2023, 80),
('gamma', 2020, 80),
('gamma', 2021, 75),
('gamma', 2022, 78),
('gamma', 2023, 80);
`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE v_sales USING pivot(
-- rows
(SELECT DISTINCT product FROM sales),
-- columns
(SELECT DISTINCT year, year FROM sales),
-- cells
(SELECT sum(income) FROM sales WHERE product = ? AND year = ?)
)`)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM v_sales`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
cols := make([]string, stmt.ColumnCount())
for i := range cols {
cols[i] = stmt.ColumnName(i)
}
fmt.Println(pretty(cols))
for stmt.Step() {
for i := range cols {
cols[i] = stmt.ColumnText(i)
}
fmt.Println(pretty(cols))
}
if err := stmt.Reset(); err != nil {
log.Fatal(err)
}
// Output:
// product 2020 2021 2022 2023
// alpha 100 120 130 140
// beta 10 20 40 80
// gamma 80 75 78 80
}
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE r AS
SELECT 1 id UNION SELECT 2 UNION SELECT 3;
CREATE TABLE c(
id INTEGER PRIMARY KEY,
name TEXT
);
INSERT INTO c (name) VALUES
('a'),('b'),('c'),('d');
CREATE TABLE x(
r_id INT,
c_id INT,
val TEXT
);
INSERT INTO x (r_id, c_id, val)
SELECT r.id, c.id, c.name || r.id
FROM c, r;
`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE v_x USING pivot(
-- rows
(SELECT id r_id FROM r),
-- columns
(SELECT id c_id, name FROM c),
-- cells
(SELECT val FROM x WHERE r_id = ?1 AND c_id = ?2)
)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM v_x WHERE rowid <> 0 AND r_id <> 1 ORDER BY rowid, r_id DESC LIMIT 1`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %d, want 3", got)
}
}
}
func TestRegister_errors(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot(SELECT 1, SELECT 2, SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), SELECT 2, SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 2), SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), SELECT 3)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), (SELECT 3, 4))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING pivot((SELECT 1), (SELECT 1, 2), (SELECT 3))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
}
func pretty(cols []string) string {
var buf strings.Builder
for i, s := range cols {
if i != 0 {
buf.WriteByte(' ')
}
for buf.Len()%8 != 0 {
buf.WriteByte(' ')
}
buf.WriteString(s)
}
return buf.String()
}

213
ext/statement/stmt.go Normal file
View File

@@ -0,0 +1,213 @@
// Package statement defines table-valued functions using SQL.
//
// It can be used to create "parametrized views":
// pre-packaged queries that can be parametrized at query execution time.
//
// https://github.com/0x09/sqlite-statement-vtab
package statement
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
// Register registers the statement virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "statement", declare, declare)
}
type table struct {
stmt *sqlite3.Stmt
sql string
inuse bool
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, fmt.Errorf("statement: wrong number of arguments")
}
sql := "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(sql)
if err != nil {
return nil, err
}
var sep string
var str strings.Builder
str.WriteString("CREATE TABLE x(")
outputs := stmt.ColumnCount()
for i := 0; i < outputs; i++ {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
str.WriteString(" ")
str.WriteString(stmt.ColumnDeclType(i))
sep = ","
}
inputs := stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
str.WriteString("] HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
err = db.DeclareVTab(str.String())
if err != nil {
stmt.Close()
return nil, err
}
return &table{sql: sql, stmt: stmt}, nil
}
func (t *table) Close() error {
return t.stmt.Close()
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.EstimatedCost = 1000
var argvIndex = 1
var needIndex bool
var listIndex []int
outputs := t.stmt.ColumnCount()
for i, cst := range idx.Constraint {
// Skip if this is a constraint on one of our output columns.
if cst.Column < outputs {
continue
}
// A given query plan is only usable if all provided input columns
// are usable and have equal constraints only.
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
// The non-zero argvIdx values must be contiguous.
// If they're not, build a list and serialize it through IdxStr.
nextIndex := cst.Column - outputs + 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
}
if nextIndex != argvIndex {
needIndex = true
}
listIndex = append(listIndex, nextIndex)
argvIndex++
}
if needIndex {
buf, err := json.Marshal(listIndex)
if err != nil {
return err
}
idx.IdxStr = unsafe.String(&buf[0], len(buf))
}
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
var err error
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
}
}
return &cursor{table: t, stmt: stmt}, nil
}
func (t *table) Rename(new string) error {
return nil
}
type cursor struct {
table *table
stmt *sqlite3.Stmt
arg []sqlite3.Value
rowID int64
}
func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
c.stmt.ClearBindings()
return c.stmt.Reset()
}
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.arg = arg
c.rowID = 0
c.stmt.ClearBindings()
if err := c.stmt.Reset(); err != nil {
return err
}
var list []int
if idxStr != "" {
buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr))
err := json.Unmarshal(buf, &list)
if err != nil {
return err
}
}
for i, arg := range arg {
param := i + 1
if list != nil {
param = list[i]
}
err := c.stmt.BindValue(param, arg)
if err != nil {
return err
}
}
return c.Next()
}
func (c *cursor) Next() error {
if c.stmt.Step() {
c.rowID++
}
return c.stmt.Err()
}
func (c *cursor) EOF() bool {
return !c.stmt.Busy()
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
switch outputs := c.stmt.ColumnCount(); {
case col < outputs:
ctx.ResultValue(c.stmt.ColumnValue(col))
case col-outputs < len(c.arg):
ctx.ResultValue(c.arg[col-outputs])
}
return nil
}

145
ext/statement/stmt_test.go Normal file
View File

@@ -0,0 +1,145 @@
package statement_test
import (
"fmt"
"log"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/statement"
)
func Example() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE split_date USING statement((
SELECT
strftime('%Y', :date) AS year,
strftime('%m', :date) AS month,
strftime('%d', :date) AS day
))`)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM split_date('2022-02-22')`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
fmt.Printf("Twosday was %d-%d-%d", stmt.ColumnInt(0), stmt.ColumnInt(1), stmt.ColumnInt(2))
}
if err := stmt.Reset(); err != nil {
log.Fatal(err)
}
// Output:
// Twosday was 2022-2-22
}
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE arguments USING statement((SELECT ? AS a, ? AS b, ? AS c))
`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`
SELECT * from arguments WHERE [2] = 'y' AND [3] = 'z'
`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE hypot USING statement((SELECT sqrt(:x * :x + :y * :y) AS hypotenuse))
`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`
SELECT x, y, * FROM hypot WHERE x = 3 AND y = 4
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
x := stmt.ColumnInt(0)
y := stmt.ColumnInt(1)
hypot := stmt.ColumnInt(2)
if x != 3 || y != 4 || hypot != 5 {
t.Errorf("hypot(%d, %d) = %d", x, y, hypot)
}
}
}
func TestRegister_errors(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
statement.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement()`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement(SELECT 1, SELECT 2)`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement((SELECT 1, SELECT 2))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement((SELECT 1; SELECT 2))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement((CREATE TABLE x(val)))`)
if err == nil {
t.Fatal("want error")
} else {
t.Log(err)
}
}

47
ext/stats/TODO.md Normal file
View File

@@ -0,0 +1,47 @@
# ANSI SQL Aggregate Functions
https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
## Built in aggregates
- [x] `COUNT(*)`
- [x] `COUNT(expression)`
- [x] `SUM(expression)`
- [x] `AVG(expression)`
- [x] `MIN(expression)`
- [x] `MAX(expression)`
https://sqlite.org/lang_aggfunc.html
## Statistical aggregates
- [x] `STDDEV_POP(expression)`
- [x] `STDDEV_SAMP(expression)`
- [x] `VAR_POP(expression)`
- [x] `VAR_SAMP(expression)`
- [x] `COVAR_POP(dependent, independent)`
- [x] `COVAR_SAMP(dependent, independent)`
- [x] `CORR(dependent, independent)`
## Linear regression aggregates
- [X] `REGR_AVGX(dependent, independent)`
- [X] `REGR_AVGY(dependent, independent)`
- [X] `REGR_SXX(dependent, independent)`
- [X] `REGR_SYY(dependent, independent)`
- [X] `REGR_SXY(dependent, independent)`
- [X] `REGR_COUNT(dependent, independent)`
- [X] `REGR_SLOPE(dependent, independent)`
- [X] `REGR_INTERCEPT(dependent, independent)`
- [X] `REGR_R2(dependent, independent)`
## Set aggregates
- [X] `CUME_DIST() OVER window`
- [X] `RANK() OVER window`
- [X] `DENSE_RANK() OVER window`
- [X] `PERCENT_RANK() OVER window`
- [ ] `PERCENTILE_CONT(percentile) OVER window`
- [ ] `PERCENTILE_DISC(percentile) OVER window`
https://sqlite.org/windowfunctions.html#builtins

View File

@@ -1,6 +1,6 @@
// Package stats provides aggregate functions for statistics.
//
// Functions:
// Provided functions:
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - var_pop: population variance
@@ -8,9 +8,26 @@
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
// - regr_r2: correlation coefficient squared
// - regr_avgx: average of the independent variable
// - regr_avgy: average of the dependent variable
// - regr_sxx: sum of the squares of the independent variable
// - regr_syy: sum of the squares of the dependent variable
// - regr_sxy: sum of the products of each pair of variables
// - regr_count: count non-null pairs of variables
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
//
// These join the [Built-in Aggregate Functions]:
// - count: count rows/values
// - sum: sum values
// - avg: average value
// - min: minimum value
// - max: maximum value
//
// See: [ANSI SQL Aggregate Functions]
//
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
@@ -26,6 +43,15 @@ func Register(db *sqlite3.Conn) {
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx))
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy))
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy))
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx))
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy))
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope))
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
}
const (
@@ -34,6 +60,15 @@ const (
stddev_pop
stddev_samp
corr
regr_r2
regr_sxx
regr_syy
regr_sxy
regr_avgx
regr_avgy
regr_slope
regr_intercept
regr_count
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
@@ -61,13 +96,13 @@ func (fn *variance) Value(ctx sqlite3.Context) {
}
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
if a := arg[0]; a.NumericType() != sqlite3.NULL {
fn.enqueue(a.Float())
}
}
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
if a := arg[0]; a.NumericType() != sqlite3.NULL {
fn.dequeue(a.Float())
}
}
@@ -90,20 +125,39 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
r = fn.covar_samp()
case corr:
r = fn.correlation()
case regr_r2:
r = fn.regr_r2()
case regr_sxx:
r = fn.regr_sxx()
case regr_syy:
r = fn.regr_syy()
case regr_sxy:
r = fn.regr_sxy()
case regr_avgx:
r = fn.regr_avgx()
case regr_avgy:
r = fn.regr_avgy()
case regr_slope:
r = fn.regr_slope()
case regr_intercept:
r = fn.regr_intercept()
case regr_count:
ctx.ResultInt64(fn.regr_count())
return
}
ctx.ResultFloat(r)
}
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
if a.NumericType() != sqlite3.NULL && b.NumericType() != sqlite3.NULL {
fn.enqueue(a.Float(), b.Float())
}
}
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
if a.NumericType() != sqlite3.NULL && b.NumericType() != sqlite3.NULL {
fn.dequeue(a.Float(), b.Float())
}
}

View File

@@ -1,4 +1,4 @@
package stats
package stats_test
import (
"math"
@@ -6,6 +6,7 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
)
func TestRegister_variance(t *testing.T) {
@@ -17,7 +18,7 @@ func TestRegister_variance(t *testing.T) {
}
defer db.Close()
Register(db)
stats.Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
if err != nil {
@@ -89,20 +90,25 @@ func TestRegister_covariance(t *testing.T) {
}
defer db.Close()
Register(db)
stats.Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (y, x)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x),
regr_count(y, x)
FROM data`)
if err != nil {
t.Fatal(err)
}
@@ -118,10 +124,37 @@ func TestRegister_covariance(t *testing.T) {
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
if got := stmt.ColumnFloat(3); got != 4.2 {
t.Errorf("got %v, want 4.2", got)
}
if got := stmt.ColumnFloat(4); got != 75 {
t.Errorf("got %v, want 75", got)
}
if got := stmt.ColumnFloat(5); got != 14.8 {
t.Errorf("got %v, want 14.8", got)
}
if got := stmt.ColumnFloat(6); got != 500 {
t.Errorf("got %v, want 500", got)
}
if got := stmt.ColumnFloat(7); got != 85 {
t.Errorf("got %v, want 85", got)
}
if got := stmt.ColumnFloat(8); got != 0.17 {
t.Errorf("got %v, want 0.17", got)
}
if got := stmt.ColumnFloat(9); got != -8.55 {
t.Errorf("got %v, want -8.55", got)
}
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
}
if got := stmt.ColumnInt(11); got != 5 {
t.Errorf("got %v, want 5", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
@@ -138,3 +171,67 @@ func TestRegister_covariance(t *testing.T) {
}
}
}
func Benchmark_average(b *testing.B) {
db, err := sqlite3.Open(":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT avg(value) FROM generate_series(0, ?)`)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()
err = stmt.BindInt(1, b.N)
if err != nil {
b.Fatal(err)
}
if stmt.Step() {
want := float64(b.N) / 2
if got := stmt.ColumnFloat(0); got != want {
b.Errorf("got %v, want %v", got, want)
}
}
err = stmt.Err()
if err != nil {
b.Error(err)
}
}
func Benchmark_variance(b *testing.B) {
db, err := sqlite3.Open(":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()
stats.Register(db)
stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()
err = stmt.BindInt(1, b.N)
if err != nil {
b.Fatal(err)
}
if stmt.Step() && b.N > 100 {
want := float64(b.N*b.N) / 12
if got := stmt.ColumnFloat(0); want > (got-want)*float64(b.N) {
b.Errorf("got %v, want %v", got, want)
}
}
err = stmt.Err()
if err != nil {
b.Error(err)
}
}

View File

@@ -6,9 +6,12 @@ import "math"
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
// See also:
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
type welford struct {
m1, m2 kahan
n uint64
n int64
}
func (w welford) average() float64 {
@@ -48,10 +51,10 @@ func (w *welford) dequeue(x float64) {
}
type welford2 struct {
m1x, m2x kahan
m1y, m2y kahan
m1x, m2x kahan
cov kahan
n uint64
n int64
}
func (w welford2) covar_pop() float64 {
@@ -63,33 +66,72 @@ func (w welford2) covar_samp() float64 {
}
func (w welford2) correlation() float64 {
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
return w.cov.hi / math.Sqrt(w.m2y.hi*w.m2x.hi)
}
func (w *welford2) enqueue(x, y float64) {
func (w welford2) regr_avgy() float64 {
return w.m1y.hi
}
func (w welford2) regr_avgx() float64 {
return w.m1x.hi
}
func (w welford2) regr_syy() float64 {
return w.m2y.hi
}
func (w welford2) regr_sxx() float64 {
return w.m2x.hi
}
func (w welford2) regr_sxy() float64 {
return w.cov.hi
}
func (w welford2) regr_count() int64 {
return w.n
}
func (w welford2) regr_slope() float64 {
return w.cov.hi / w.m2x.hi
}
func (w welford2) regr_intercept() float64 {
slope := -w.regr_slope()
hi := math.FMA(slope, w.m1x.hi, w.m1y.hi)
lo := math.FMA(slope, w.m1x.lo, w.m1y.lo)
return hi + lo
}
func (w welford2) regr_r2() float64 {
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
}
func (w *welford2) enqueue(y, x float64) {
w.n++
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.add(d1x / float64(w.n))
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.add(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
w.m1x.add(d1x / float64(w.n))
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.add(d1x * d2x)
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.add(d1y * d2y)
w.cov.add(d1x * d2y)
w.m2x.add(d1x * d2x)
w.cov.add(d1y * d2x)
}
func (w *welford2) dequeue(x, y float64) {
func (w *welford2) dequeue(y, x float64) {
w.n--
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.sub(d1x / float64(w.n))
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.sub(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
w.m1x.sub(d1x / float64(w.n))
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.sub(d1x * d2x)
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.sub(d1y * d2y)
w.cov.sub(d1x * d2y)
w.m2x.sub(d1x * d2x)
w.cov.sub(d1y * d2x)
}
type kahan struct{ hi, lo float64 }

View File

@@ -6,6 +6,8 @@ import (
)
func Test_welford(t *testing.T) {
t.Parallel()
var s1, s2 welford
s1.enqueue(4)
@@ -38,6 +40,8 @@ func Test_welford(t *testing.T) {
}
func Test_covar(t *testing.T) {
t.Parallel()
var c1, c2 welford2
c1.enqueue(3, 70)
@@ -64,6 +68,8 @@ func Test_covar(t *testing.T) {
}
func Test_correlation(t *testing.T) {
t.Parallel()
var c welford2
c.enqueue(1, 3)
c.enqueue(2, 2)

View File

@@ -8,7 +8,7 @@
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regex/syntax];
// - the REGEXP operator uses Go [regexp/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
@@ -113,14 +113,14 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re = r
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
ctx.ResultBool(re.Match(arg[1].RawText()))
}
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
escape := rune(-1)
if len(arg) == 3 {
var size int
b := arg[2].RawBlob()
b := arg[2].RawText()
escape, size = utf8.DecodeRune(b)
if size != len(b) {
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
@@ -141,7 +141,7 @@ func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
ctx.ResultBool(re.Match(arg[1].RawText()))
}
func like2regex(pattern string, escape rune) string {

View File

@@ -189,6 +189,8 @@ func TestRegister_error(t *testing.T) {
}
func Test_like2regex(t *testing.T) {
t.Parallel()
const prefix = `(?is)\A`
const sufix = `\z`
tests := []struct {

146
func.go
View File

@@ -2,6 +2,7 @@ package sqlite3
import (
"context"
"sync"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
@@ -14,17 +15,17 @@ import (
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
}
// CreateCollation defines a new collating sequence.
//
// https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
defer c.arena.reset()
defer c.arena.mark()()
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createCollation,
r := c.call("sqlite3_create_collation_go",
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
return c.error(r)
}
@@ -32,28 +33,32 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
defer c.arena.reset()
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
r := c.call("sqlite3_create_function_go",
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// ScalarFunction is the type of a scalar SQL function.
// Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value)
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
defer c.arena.reset()
call := c.api.createAggregate
defer c.arena.mark()()
call := "sqlite3_create_aggregate_function_go"
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
if _, ok := fn().(WindowFunction); ok {
call = c.api.createWindow
call = "sqlite3_create_window_function_go"
}
r := c.call(call,
uint64(c.handle), uint64(namePtr), uint64(nArg),
@@ -66,7 +71,8 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn
// https://sqlite.org/appfunc.html
type AggregateFunction interface {
// Step is invoked to add a row to the current window.
// The function arguments, if any, corresponding to the row being added are passed to Step.
// The function arguments, if any, corresponding to the row being added, are passed to Step.
// Implementations must not retain arg.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current (or final) value of the aggregate.
@@ -81,9 +87,21 @@ type WindowFunction interface {
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
// The function arguments, if any, are those passed to Step for the row being removed.
// Implementations must not retain arg.
Inverse(ctx Context, arg ...Value)
}
// OverloadFunction overloads a function for a virtual table.
//
// https://sqlite.org/c3ref/overload_function.html
func (c *Conn) OverloadFunction(name string, nArg int) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
r := c.call("sqlite3_overload_function",
uint64(c.handle), uint64(namePtr), uint64(nArg))
return c.error(r)
}
func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
@@ -93,80 +111,80 @@ func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
fn := userDataHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
callbackArgs(db, args[:nArg], pArg)
fn(Context{db, pCtx}, args[:nArg]...)
}
func stepCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
fn := aggregateCtxHandle(db, pCtx, nil)
fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
callbackArgs(db, args[:nArg], pArg)
fn, _ := callbackAggregate(db, pAgg, pApp)
fn.Step(Context{db, pCtx}, args[:nArg]...)
}
func finalCallback(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := aggregateCtxHandle(db, pCtx, &handle)
fn, handle := callbackAggregate(db, pAgg, pApp)
fn.Value(Context{db, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{db, pCtx}.ResultError(err)
}
util.DelHandle(ctx, handle)
}
func valueCallback(ctx context.Context, mod api.Module, pCtx uint32) {
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := aggregateCtxHandle(db, pCtx, nil)
fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction)
fn.Value(Context{db, pCtx})
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
fn := aggregateCtxHandle(db, pCtx, nil).(WindowFunction)
fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
callbackArgs(db, args[:nArg], pArg)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
fn.Inverse(Context{db, pCtx}, args[:nArg]...)
}
func userDataHandle(db *Conn, pCtx uint32) any {
pApp := uint32(db.call(db.api.userData, uint64(pCtx)))
return util.GetHandle(db.ctx, pApp)
func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) {
if pApp == 0 {
handle := util.ReadUint32(db.mod, pAgg)
return util.GetHandle(db.ctx, handle).(AggregateFunction), handle
}
// We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
handle := util.AddHandle(db.ctx, fn)
if pAgg != 0 {
util.WriteUint32(db.mod, pAgg, handle)
}
return fn, handle
}
func aggregateCtxHandle(db *Conn, pCtx uint32, close *uint32) AggregateFunction {
// On close, we're getting rid of the aggregate.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(db.call(db.api.aggregateCtx, uint64(pCtx), size))
// If we already have an aggregate, return it.
if ptr != 0 {
if handle := util.ReadUint32(db.mod, ptr); handle != 0 {
fn := util.GetHandle(db.ctx, handle).(AggregateFunction)
if close != nil {
*close = handle
}
return fn
}
}
// Create a new aggregate, and store it if needed.
fn := userDataHandle(db, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn))
}
return fn
}
func callbackArgs(db *Conn, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
sqlite: db.sqlite,
func callbackArgs(db *Conn, arg []Value, pArg uint32) {
for i := range arg {
arg[i] = Value{
c: db,
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
}
}
return args
}
var funcArgsPool sync.Pool
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
funcArgsPool.Put(p)
}
func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
if p := funcArgsPool.Get(); p == nil {
return new([_MAX_FUNCTION_ARG]Value)
} else {
return p.(*[_MAX_FUNCTION_ARG]Value)
}
}

View File

@@ -129,7 +129,7 @@ func ExampleContext_SetAuxData() {
ctx.SetAuxData(0, r)
re = r
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
ctx.ResultBool(re.Match(arg[1].RawText()))
})
if err != nil {
log.Fatal(err)

7
go.mod
View File

@@ -5,9 +5,10 @@ go 1.21
require (
github.com/ncruces/julianday v1.0.0
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.5.0
golang.org/x/sync v0.5.0
golang.org/x/sys v0.14.0
github.com/tetratelabs/wazero v1.6.0
golang.org/x/crypto v0.18.0
golang.org/x/sync v0.6.0
golang.org/x/sys v0.16.0
golang.org/x/text v0.14.0
)

14
go.sum
View File

@@ -2,11 +2,13 @@ github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
github.com/tetratelabs/wazero v1.6.0 h1:z0H1iikCdP8t+q341xqepY4EWvHEw8Es7tlqiVzlP3g=
github.com/tetratelabs/wazero v1.6.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=

View File

@@ -1,4 +1,3 @@
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=

View File

@@ -3,7 +3,7 @@ module github.com/ncruces/go-sqlite3/gormlite
go 1.21
require (
github.com/ncruces/go-sqlite3 v0.10.5
github.com/ncruces/go-sqlite3 v0.11.0
gorm.io/gorm v1.25.5
)
@@ -12,5 +12,5 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/tetratelabs/wazero v1.5.0 // indirect
golang.org/x/sys v0.14.0 // indirect
golang.org/x/sys v0.15.0 // indirect
)

View File

@@ -2,14 +2,14 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.10.5 h1:SPnFFYajDfhTuJNjeNwdOhwVCRSAqB1PdSHsGrdfYjw=
github.com/ncruces/go-sqlite3 v0.10.5/go.mod h1:8aGu9/G8lLZbvO6TXA0FXTP2liIefFmbpeXuhG4nJLw=
github.com/ncruces/go-sqlite3 v0.11.0 h1:PDjs8Ve2Z0GWmHyKQHGUyG78grCXKhiHCUZQI8CqXO8=
github.com/ncruces/go-sqlite3 v0.11.0/go.mod h1:zaYJ6xP+EQiWJCa3nd3h28cD8DuSIcIqh+LrJMrBN9k=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=

View File

@@ -13,7 +13,6 @@ import (
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
)
@@ -40,9 +39,7 @@ func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := driver.Open(dialector.DSN, func(c *sqlite3.Conn) error {
return c.Exec("PRAGMA foreign_keys = ON")
})
conn, err := driver.Open(dialector.DSN, nil)
if err != nil {
return err
}

View File

@@ -3,6 +3,8 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
go test
rm -rf gorm/ tests/
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
mv gorm/tests tests

View File

@@ -1,7 +1,6 @@
package util
import (
"fmt"
"runtime"
"strconv"
)
@@ -15,9 +14,8 @@ const (
OOMErr = ErrorString("sqlite3: out of memory")
RangeErr = ErrorString("sqlite3: index out of range")
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
NoGlobalErr = ErrorString("sqlite3: could not find global: ")
NoFuncErr = ErrorString("sqlite3: could not find function: ")
BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
NoBinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
BadBinaryErr = ErrorString("sqlite3: invalid SQLite binary embed/set/loaded")
TimeErr = ErrorString("sqlite3: invalid time value")
WhenceErr = ErrorString("sqlite3: invalid whence")
OffsetErr = ErrorString("sqlite3: invalid offset")
@@ -35,14 +33,6 @@ func AssertErr() ErrorString {
return ErrorString(msg)
}
func Finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(ErrorString(msg)) }
}
func ErrorCodeString(rc uint32) string {
switch rc {
case ABORT_ROLLBACK:

View File

@@ -23,6 +23,19 @@ func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(con
Export(name)
}
type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1)
func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]), T1(stack[1]))
}
func ExportFuncVII[T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVII[T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
@@ -36,6 +49,32 @@ func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, f
Export(name)
}
type funcVIIII[T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3)
func (fn funcVIIII[T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))
}
func ExportFuncVIIII[T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIIII[T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcVIIIII[T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIII[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
}
func ExportFuncVIIIII[T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIIIII[T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {

35
internal/util/json.go Normal file
View File

@@ -0,0 +1,35 @@
package util
import (
"encoding/json"
"strconv"
"time"
"unsafe"
)
type JSON struct{ Value any }
func (j JSON) Scan(value any) error {
var buf []byte
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = append(buf, "null"...)
default:
panic(AssertErr())
}
return json.Unmarshal(buf, j.Value)
}

11
internal/util/pointer.go Normal file
View File

@@ -0,0 +1,11 @@
package util
type Pointer[T any] struct{ Value T }
func (p Pointer[T]) unwrap() any { return p.Value }
type PointerUnwrap interface{ unwrap() any }
func UnwrapPointer(p PointerUnwrap) any {
return p.unwrap()
}

View File

@@ -0,0 +1,15 @@
package util_test
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func TestUnwrapPointer(t *testing.T) {
p := util.Pointer[float64]{Value: math.Pi}
if got := util.UnwrapPointer(p); got != math.Pi {
t.Errorf("want π, got %v", got)
}
}

10
internal/util/reflect.go Normal file
View File

@@ -0,0 +1,10 @@
package util
import "reflect"
func ReflectType(v reflect.Value) reflect.Type {
if v.Kind() != reflect.Invalid {
return v.Type()
}
return nil
}

View File

@@ -0,0 +1,21 @@
package util
import (
"fmt"
"math"
"reflect"
"testing"
)
func TestReflectType(t *testing.T) {
tests := []any{nil, 1, math.Pi, "abc"}
for _, tt := range tests {
t.Run(fmt.Sprint(tt), func(t *testing.T) {
want := fmt.Sprintf("%T", tt)
got := fmt.Sprintf("%v", ReflectType(reflect.ValueOf(tt)))
if got != want {
t.Errorf("ReflectType() = %v, want %v", got, want)
}
})
}
}

41
json.go
View File

@@ -1,46 +1,11 @@
package sqlite3
import (
"encoding/json"
"strconv"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
import "github.com/ncruces/go-sqlite3/internal/util"
// JSON returns a value that can be used as an argument to
// [database/sql.DB.Exec], [database/sql.Row.Scan] and similar methods to
// store value as JSON, or decode JSON into value.
// JSON should NOT be used with [BindJSON] or [ResultJSON].
func JSON(value any) any {
return jsonValue{value}
}
type jsonValue struct{ any }
func (j jsonValue) JSON() any { return j.any }
func (j jsonValue) Scan(value any) error {
var buf []byte
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = append(buf, "null"...)
default:
panic(util.AssertErr())
}
return json.Unmarshal(buf, j.any)
return util.JSON{Value: value}
}

View File

@@ -1,14 +1,12 @@
package sqlite3
// Pointer returns a pointer to a value
// that can be used as an argument to
import "github.com/ncruces/go-sqlite3/internal/util"
// Pointer returns a pointer to a value that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
// Pointer should NOT be used with [BindPointer] or [ResultPointer].
//
// https://sqlite.org/bindptr.html
func Pointer[T any](val T) any {
return pointer[T]{val}
func Pointer[T any](value T) any {
return util.Pointer[T]{Value: value}
}
type pointer[T any] struct{ val T }
func (p pointer[T]) Pointer() any { return p.val }

281
sqlite.go
View File

@@ -4,8 +4,10 @@ package sqlite3
import (
"context"
"math"
"math/bits"
"os"
"sync"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
@@ -57,7 +59,7 @@ func compileSQLite() {
}
}
if bin == nil {
instance.err = util.BinaryErr
instance.err = util.NoBinaryErr
return
}
@@ -67,8 +69,13 @@ func compileSQLite() {
type sqlite struct {
ctx context.Context
mod api.Module
api sqliteAPI
funcs struct {
fn [32]api.Function
id [32]*byte
mask uint32
}
stack [8]uint64
freer uint32
}
func instantiateSQLite() (sqlt *sqlite, err error) {
@@ -86,108 +93,12 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
return nil, err
}
getFun := func(name string) api.Function {
f := sqlt.mod.ExportedFunction(name)
if f == nil {
err = util.NoFuncErr + util.ErrorString(name)
return nil
}
return f
global := sqlt.mod.ExportedGlobal("malloc_destructor")
if global == nil {
return nil, util.BadBinaryErr
}
getVal := func(name string) uint32 {
g := sqlt.mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
}
sqlt.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
interrupt: getFun("sqlite3_interrupt"),
progressHandler: getFun("sqlite3_progress_handler_go"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindPointer: getFun("sqlite3_bind_pointer_go"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
anyCollation: getFun("sqlite3_anycollseq_init"),
createCollation: getFun("sqlite3_create_collation_go"),
createFunction: getFun("sqlite3_create_function_go"),
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
createWindow: getFun("sqlite3_create_window_function_go"),
aggregateCtx: getFun("sqlite3_aggregate_context"),
userData: getFun("sqlite3_user_data"),
setAuxData: getFun("sqlite3_set_auxdata_go"),
getAuxData: getFun("sqlite3_get_auxdata"),
valueType: getFun("sqlite3_value_type"),
valueInteger: getFun("sqlite3_value_int64"),
valueFloat: getFun("sqlite3_value_double"),
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
valuePointer: getFun("sqlite3_value_pointer_go"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultPointer: getFun("sqlite3_result_pointer_go"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
resultErrorBig: getFun("sqlite3_result_error_toobig"),
createModule: getFun("sqlite3_create_module_go"),
declareVTab: getFun("sqlite3_declare_vtab"),
vtabConfig: getFun("sqlite3_vtab_config_go"),
vtabRHSValue: getFun("sqlite3_vtab_rhs_value"),
}
sqlt.freer = util.ReadUint32(sqlt.mod, uint32(global.Get()))
if err != nil {
return nil, err
}
@@ -209,17 +120,17 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
if r := sqlt.call("sqlite3_errstr", rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
}
if handle != 0 {
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH)
}
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
@@ -232,12 +143,42 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
func (sqlt *sqlite) getfn(name string) api.Function {
c := &sqlt.funcs
p := unsafe.StringData(name)
for i := range c.id {
if c.id[i] == p {
c.id[i] = nil
c.mask &^= uint32(1) << i
return c.fn[i]
}
}
return sqlt.mod.ExportedFunction(name)
}
func (sqlt *sqlite) putfn(name string, fn api.Function) {
c := &sqlt.funcs
p := unsafe.StringData(name)
i := bits.TrailingZeros32(^c.mask)
if i < 32 {
c.id[i] = p
c.fn[i] = fn
c.mask |= uint32(1) << i
} else {
c.id[0] = p
c.fn[0] = fn
c.mask = uint32(1)
}
}
func (sqlt *sqlite) call(name string, params ...uint64) uint64 {
copy(sqlt.stack[:], params)
fn := sqlt.getfn(name)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
if err != nil {
panic(err)
}
sqlt.putfn(name, fn)
return sqlt.stack[0]
}
@@ -245,14 +186,14 @@ func (sqlt *sqlite) free(ptr uint32) {
if ptr == 0 {
return
}
sqlt.call(sqlt.api.free, uint64(ptr))
sqlt.call("free", uint64(ptr))
}
func (sqlt *sqlite) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
ptr := uint32(sqlt.call("malloc", size))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
@@ -275,6 +216,8 @@ func (sqlt *sqlite) newString(s string) uint32 {
}
func (sqlt *sqlite) newArena(size uint64) arena {
// Ensure the arena's size is a multiple of 8.
size = (size + 7) &^ 7
return arena{
sqlt: sqlt,
size: uint32(size),
@@ -294,20 +237,32 @@ func (a *arena) free() {
if a.sqlt == nil {
return
}
a.reset()
for _, ptr := range a.ptrs {
a.sqlt.free(ptr)
}
a.sqlt.free(a.base)
a.sqlt = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.sqlt.free(ptr)
func (a *arena) mark() (reset func()) {
ptrs := len(a.ptrs)
next := a.next
return func() {
for _, ptr := range a.ptrs[ptrs:] {
a.sqlt.free(ptr)
}
a.ptrs = a.ptrs[:ptrs]
a.next = next
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
// Align the next address, to 4 or 8 bytes.
if size&7 != 0 {
a.next = (a.next + 3) &^ 3
} else {
a.next = (a.next + 7) &^ 7
}
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
@@ -333,99 +288,15 @@ func (a *arena) string(s string) uint32 {
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
interrupt api.Function
progressHandler api.Function
clearBindings api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindNull api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindPointer api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
changes api.Function
lastRowid api.Function
autocommit api.Function
anyCollation api.Function
createCollation api.Function
createFunction api.Function
createAggregate api.Function
createWindow api.Function
aggregateCtx api.Function
userData api.Function
setAuxData api.Function
getAuxData api.Function
valueType api.Function
valueInteger api.Function
valueFloat api.Function
valueText api.Function
valueBlob api.Function
valueBytes api.Function
valuePointer api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultPointer api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
createModule api.Function
declareVTab api.Function
vtabConfig api.Function
vtabRHSValue api.Function
destructor uint32
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", progressCallback)
util.ExportFuncVIII(env, "go_log", logCallback)
util.ExportFuncVI(env, "go_destroy", destroyCallback)
util.ExportFuncVIII(env, "go_func", funcCallback)
util.ExportFuncVIII(env, "go_step", stepCallback)
util.ExportFuncVI(env, "go_final", finalCallback)
util.ExportFuncVI(env, "go_value", valueCallback)
util.ExportFuncVIII(env, "go_inverse", inverseCallback)
util.ExportFuncVIIII(env, "go_func", funcCallback)
util.ExportFuncVIIIII(env, "go_step", stepCallback)
util.ExportFuncVIII(env, "go_final", finalCallback)
util.ExportFuncVII(env, "go_value", valueCallback)
util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
util.ExportFuncIIIIII(env, "go_compare", compareCallback)
util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(0))
util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(1))

51
sqlite3/column.c Normal file
View File

@@ -0,0 +1,51 @@
#include <stddef.h>
#include "sqlite3.h"
union sqlite3_data {
sqlite3_int64 i;
double d;
struct {
const void *ptr;
int len;
};
};
int sqlite3_columns_go(sqlite3_stmt *stmt, int nCol, char *aType,
union sqlite3_data *aData) {
if (nCol != sqlite3_column_count(stmt)) {
return SQLITE_MISUSE;
}
int rc = SQLITE_OK;
for (int i = 0; i < nCol; ++i) {
const void *ptr = NULL;
switch (aType[i] = sqlite3_column_type(stmt, i)) {
default: // SQLITE_NULL
aData[i] = (union sqlite3_data){};
case SQLITE_INTEGER:
aData[i].i = sqlite3_column_int64(stmt, i);
continue;
case SQLITE_FLOAT:
aData[i].d = sqlite3_column_double(stmt, i);
continue;
case SQLITE_TEXT:
ptr = sqlite3_column_text(stmt, i);
break;
case SQLITE_BLOB:
ptr = sqlite3_column_blob(stmt, i);
break;
}
if (ptr == NULL && rc == SQLITE_OK) {
rc = sqlite3_errcode(sqlite3_db_handle(stmt));
}
aData[i].ptr = ptr;
aData[i].len = sqlite3_column_bytes(stmt, i);
}
return rc;
}
static_assert(offsetof(union sqlite3_data, i) == 0, "Unexpected offset");
static_assert(offsetof(union sqlite3_data, d) == 0, "Unexpected offset");
static_assert(offsetof(union sqlite3_data, ptr) == 0, "Unexpected offset");
static_assert(offsetof(union sqlite3_data, len) == 4, "Unexpected offset");
static_assert(sizeof(union sqlite3_data) == 8, "Unexpected size");

View File

@@ -3,33 +3,34 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3440200.zip"
curl -#OL "https://sqlite.org/2024/sqlite-amalgamation-3450000.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
cat *.patch | patch --posix
cat *.patch | patch --no-backup-if-mismatch
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/ieee754.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/ext/misc/uuid.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/mptest/multiwrite01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.2/test/speedtest1.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.45.0/test/speedtest1.c"
cd ~-

View File

@@ -3,14 +3,47 @@
#include "include.h"
#include "sqlite3.h"
void go_func(sqlite3_context *, int, sqlite3_value **);
void go_step(sqlite3_context *, int, sqlite3_value **);
void go_final(sqlite3_context *);
void go_value(sqlite3_context *);
void go_inverse(sqlite3_context *, int, sqlite3_value **);
int go_compare(go_handle, int, const void *, int, const void *);
void go_func(sqlite3_context *, go_handle, int, sqlite3_value **);
void go_step(sqlite3_context *, go_handle *, go_handle, int, sqlite3_value **);
void go_final(sqlite3_context *, go_handle, go_handle);
void go_value(sqlite3_context *, go_handle);
void go_inverse(sqlite3_context *, go_handle *, int, sqlite3_value **);
void go_func_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_func(ctx, sqlite3_user_data(ctx), nArg, pArg);
}
void go_step_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_handle *agg = sqlite3_aggregate_context(ctx, 4);
go_handle data = NULL;
if (agg == NULL || *agg == NULL) {
data = sqlite3_user_data(ctx);
}
go_step(ctx, agg, data, nArg, pArg);
}
void go_final_wrapper(sqlite3_context *ctx) {
go_handle *agg = sqlite3_aggregate_context(ctx, 0);
go_handle data = NULL;
if (agg == NULL || *agg == NULL) {
data = sqlite3_user_data(ctx);
}
go_final(ctx, agg, data);
}
void go_value_wrapper(sqlite3_context *ctx) {
go_handle *agg = sqlite3_aggregate_context(ctx, 4);
go_value(ctx, *agg);
}
void go_inverse_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_handle *agg = sqlite3_aggregate_context(ctx, 4);
go_inverse(ctx, *agg, nArg, pArg);
}
int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) {
int rc = sqlite3_create_collation_v2(db, name, SQLITE_UTF8, app, go_compare,
go_destroy);
@@ -21,22 +54,22 @@ int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) {
int sqlite3_create_function_go(sqlite3 *db, const char *name, int argc,
int flags, go_handle app) {
return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, app,
go_func, /*step=*/NULL, /*final=*/NULL,
go_destroy);
go_func_wrapper, /*step=*/NULL,
/*final=*/NULL, go_destroy);
}
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *name,
int argc, int flags, go_handle app) {
return sqlite3_create_window_function(db, name, argc, SQLITE_UTF8 | flags,
app, go_step, go_final, /*value=*/NULL,
/*inverse=*/NULL, go_destroy);
return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, app,
/*func=*/NULL, go_step_wrapper,
go_final_wrapper, go_destroy);
}
int sqlite3_create_window_function_go(sqlite3 *db, const char *name, int argc,
int flags, go_handle app) {
return sqlite3_create_window_function(db, name, argc, SQLITE_UTF8 | flags,
app, go_step, go_final, go_value,
go_inverse, go_destroy);
return sqlite3_create_window_function(
db, name, argc, SQLITE_UTF8 | flags, app, go_step_wrapper,
go_final_wrapper, go_value_wrapper, go_inverse_wrapper, go_destroy);
}
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int i, go_handle aux) {

9
sqlite3/log.c Normal file
View File

@@ -0,0 +1,9 @@
#include <stdbool.h>
#include "sqlite3.h"
void go_log(void *, int, const char *);
int sqlite3_config_log_go(bool enable) {
return sqlite3_config(SQLITE_CONFIG_LOG, enable ? go_log : NULL, NULL);
}

View File

@@ -4,12 +4,15 @@
#include "ext/anycollseq.c"
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/ieee754.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
// Bindings
#include "column.c"
#include "func.c"
#include "log.c"
#include "pointer.c"
#include "progress.c"
#include "time.c"
@@ -22,6 +25,7 @@ __attribute__((constructor)) void init() {
sqlite3_initialize();
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
sqlite3_auto_extension((void (*)(void))sqlite3_decimal_init);
sqlite3_auto_extension((void (*)(void))sqlite3_ieee_init);
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);

View File

@@ -37,18 +37,20 @@
#define SQLITE_DEFAULT_WAL_SYNCHRONOUS 1
#define SQLITE_LIKE_DOESNT_MATCH_BLOBS
#define SQLITE_MAX_EXPR_DEPTH 0
#define SQLITE_OMIT_DECLTYPE
#define SQLITE_USE_ALLOCA
#define SQLITE_OMIT_DEPRECATED
#define SQLITE_OMIT_SHARED_CACHE
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// #define SQLITE_OMIT_DECLTYPE
// #define SQLITE_OMIT_PROGRESS_CALLBACK
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
#define SQLITE_TRUSTED_SCHEMA 0
#define SQLITE_DEFAULT_FOREIGN_KEYS 1
#define SQLITE_ENABLE_ATOMIC_WRITE
#define SQLITE_OMIT_DESERIALIZE
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
// Because WASM does not support shared memory,
// SQLite disables WAL for WASM builds.
@@ -56,6 +58,12 @@
// https://sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
// We have our own memdb VFS.
// To avoid interactions between the two,
// omit sqlite3_serialize/sqlite3_deserialize,
// which we also don't wrap.
#define SQLITE_OMIT_DESERIALIZE
// Amalgamated Extensions
#define SQLITE_ENABLE_MATH_FUNCTIONS 1

View File

@@ -10,7 +10,6 @@ int go_vfs_find(const char *zVfsName);
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
int go_sleep(sqlite3_vfs *, int microseconds);
int go_current_time(sqlite3_vfs *, double *);
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
@@ -68,7 +67,7 @@ int sqlite3_os_init() {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.mxPathname = 1024,
.zName = "os",
.xOpen = go_open_wrapper,
@@ -78,7 +77,6 @@ int sqlite3_os_init() {
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return sqlite3_vfs_register(&os_vfs, /*default=*/true);
@@ -89,11 +87,11 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
}
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
if (zVfsName) {
if (zVfsName && go_vfs_find(zVfsName)) {
static sqlite3_vfs *go_vfs_list;
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
if (!strcmp(zVfsName, it->zName) && go_vfs_find(it->zName)) {
if (!strcmp(zVfsName, it->zName)) {
return it;
}
}
@@ -108,30 +106,27 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
}
}
if (go_vfs_find(zVfsName)) {
sqlite3_vfs *head = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
*go_vfs_list = (sqlite3_vfs){
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = name,
.pNext = head,
sqlite3_vfs *head = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
*go_vfs_list = (sqlite3_vfs){
.iVersion = 2,
.szOsFile = sizeof(struct go_file),
.mxPathname = 1024,
.zName = name,
.pNext = head,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
.xAccess = go_access,
.xFullPathname = go_full_pathname,
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTime = go_current_time,
.xCurrentTimeInt64 = go_current_time_64,
};
return go_vfs_list;
}
.xRandomness = go_randomness,
.xSleep = go_sleep,
.xCurrentTimeInt64 = go_current_time_64,
};
return go_vfs_list;
}
return sqlite3_vfs_find_orig(zVfsName);
}

View File

@@ -136,12 +136,10 @@ static int go_cur_close_wrapper(sqlite3_vtab_cursor *pCursor) {
static int go_vtab_find_function_wrapper(
sqlite3_vtab *pVTab, int nArg, const char *zName,
void (**pxFunc)(sqlite3_context *, int, sqlite3_value **), void **ppArg) {
struct go_vtab *vtab = container_of(pVTab, struct go_vtab, base);
go_handle handle;
int rc = go_vtab_find_function(pVTab, nArg, zName, &handle);
if (rc) {
*pxFunc = go_func;
*pxFunc = go_func_wrapper;
*ppArg = handle;
}
return rc;

View File

@@ -37,7 +37,7 @@ func Test_sqlite_call_closed(t *testing.T) {
sqlite.close()
defer func() { _ = recover() }()
sqlite.call(sqlite.api.free)
sqlite.call("free")
t.Error("want panic")
}

195
stmt.go
View File

@@ -28,27 +28,43 @@ func (s *Stmt) Close() error {
return nil
}
r := s.c.call(s.c.api.finalize, uint64(s.handle))
r := s.c.call("sqlite3_finalize", uint64(s.handle))
s.handle = 0
return s.c.error(r)
}
// Conn returns the database connection to which the prepared statement belongs.
//
// https://sqlite.org/c3ref/db_handle.html
func (s *Stmt) Conn() *Conn {
return s.c
}
// ReadOnly returns true if and only if the statement
// makes no direct changes to the content of the database file.
//
// https://sqlite.org/c3ref/stmt_readonly.html
func (s *Stmt) ReadOnly() bool {
r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle))
return r != 0
}
// Reset resets the prepared statement object.
//
// https://sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
r := s.c.call(s.c.api.reset, uint64(s.handle))
r := s.c.call("sqlite3_reset", uint64(s.handle))
s.err = nil
return s.c.error(r)
}
// ClearBindings resets all bindings on the prepared statement.
// Busy determines if a prepared statement has been reset.
//
// https://sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r)
// https://sqlite.org/c3ref/stmt_busy.html
func (s *Stmt) Busy() bool {
r := s.c.call("sqlite3_stmt_busy", uint64(s.handle))
return r != 0
}
// Step evaluates the SQL statement.
@@ -62,9 +78,10 @@ func (s *Stmt) ClearBindings() error {
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
r := s.c.call("sqlite3_step", uint64(s.handle))
switch r {
case _ROW:
s.err = nil
return true
case _DONE:
s.err = nil
@@ -90,11 +107,32 @@ func (s *Stmt) Exec() error {
return s.Reset()
}
// Status monitors the performance characteristics of prepared statements.
//
// https://sqlite.org/c3ref/stmt_status.html
func (s *Stmt) Status(op StmtStatus, reset bool) int {
var i uint64
if reset {
i = 1
}
r := s.c.call("sqlite3_stmt_status", uint64(s.handle),
uint64(op), i)
return int(r)
}
// ClearBindings resets all bindings on the prepared statement.
//
// https://sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call("sqlite3_clear_bindings", uint64(s.handle))
return s.c.error(r)
}
// BindCount returns the number of SQL parameters in the prepared statement.
//
// https://sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r := s.c.call(s.c.api.bindCount,
r := s.c.call("sqlite3_bind_parameter_count",
uint64(s.handle))
return int(r)
}
@@ -104,9 +142,9 @@ func (s *Stmt) BindCount() int {
//
// https://sqlite.org/c3ref/bind_parameter_index.html
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.reset()
defer s.c.arena.mark()()
namePtr := s.c.arena.string(name)
r := s.c.call(s.c.api.bindIndex,
r := s.c.call("sqlite3_bind_parameter_index",
uint64(s.handle), uint64(namePtr))
return int(r)
}
@@ -116,7 +154,7 @@ func (s *Stmt) BindIndex(name string) int {
//
// https://sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r := s.c.call(s.c.api.bindName,
r := s.c.call("sqlite3_bind_parameter_name",
uint64(s.handle), uint64(param))
ptr := uint32(r)
@@ -153,7 +191,7 @@ func (s *Stmt) BindInt(param int, value int) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
r := s.c.call(s.c.api.bindInteger,
r := s.c.call("sqlite3_bind_int64",
uint64(s.handle), uint64(param), uint64(value))
return s.c.error(r)
}
@@ -163,7 +201,7 @@ func (s *Stmt) BindInt64(param int, value int64) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
r := s.c.call(s.c.api.bindFloat,
r := s.c.call("sqlite3_bind_double",
uint64(s.handle), uint64(param), math.Float64bits(value))
return s.c.error(r)
}
@@ -177,10 +215,10 @@ func (s *Stmt) BindText(param int, value string) error {
return TOOBIG
}
ptr := s.c.newString(value)
r := s.c.call(s.c.api.bindText,
r := s.c.call("sqlite3_bind_text64",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
uint64(s.c.api.destructor), _UTF8)
uint64(s.c.freer), _UTF8)
return s.c.error(r)
}
@@ -193,10 +231,10 @@ func (s *Stmt) BindRawText(param int, value []byte) error {
return TOOBIG
}
ptr := s.c.newBytes(value)
r := s.c.call(s.c.api.bindText,
r := s.c.call("sqlite3_bind_text64",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
uint64(s.c.api.destructor), _UTF8)
uint64(s.c.freer), _UTF8)
return s.c.error(r)
}
@@ -210,10 +248,10 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
return TOOBIG
}
ptr := s.c.newBytes(value)
r := s.c.call(s.c.api.bindBlob,
r := s.c.call("sqlite3_bind_blob64",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
uint64(s.c.api.destructor))
uint64(s.c.freer))
return s.c.error(r)
}
@@ -222,7 +260,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r := s.c.call(s.c.api.bindZeroBlob,
r := s.c.call("sqlite3_bind_zeroblob64",
uint64(s.handle), uint64(param), uint64(n))
return s.c.error(r)
}
@@ -232,7 +270,7 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
r := s.c.call(s.c.api.bindNull,
r := s.c.call("sqlite3_bind_null",
uint64(s.handle), uint64(param))
return s.c.error(r)
}
@@ -265,10 +303,10 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
r := s.c.call(s.c.api.bindText,
r := s.c.call("sqlite3_bind_text64",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)),
uint64(s.c.api.destructor), _UTF8)
uint64(s.c.freer), _UTF8)
return s.c.error(r)
}
@@ -280,7 +318,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
r := s.c.call(s.c.api.bindPointer,
r := s.c.call("sqlite3_bind_pointer_go",
uint64(s.handle), uint64(param), uint64(valPtr))
return s.c.error(r)
}
@@ -297,11 +335,24 @@ func (s *Stmt) BindJSON(param int, value any) error {
return s.BindRawText(param, data)
}
// BindValue binds a copy of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindValue(param int, value Value) error {
if value.c != s.c {
return MISUSE
}
r := s.c.call("sqlite3_bind_value",
uint64(s.handle), uint64(param), uint64(value.handle))
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
//
// https://sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r := s.c.call(s.c.api.columnCount,
r := s.c.call("sqlite3_column_count",
uint64(s.handle))
return int(r)
}
@@ -311,7 +362,7 @@ func (s *Stmt) ColumnCount() int {
//
// https://sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r := s.c.call(s.c.api.columnName,
r := s.c.call("sqlite3_column_name",
uint64(s.handle), uint64(col))
ptr := uint32(r)
@@ -326,11 +377,24 @@ func (s *Stmt) ColumnName(col int) string {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
r := s.c.call(s.c.api.columnType,
r := s.c.call("sqlite3_column_type",
uint64(s.handle), uint64(col))
return Datatype(r)
}
// ColumnDeclType returns the declared datatype of the result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_decltype.html
func (s *Stmt) ColumnDeclType(col int) string {
r := s.c.call("sqlite3_column_decltype",
uint64(s.handle), uint64(col))
if r == 0 {
return ""
}
return util.ReadString(s.c.mod, uint32(r), _MAX_NAME)
}
// ColumnBool returns the value of the result column as a bool.
// The leftmost column of the result set has the index 0.
// SQLite does not have a separate boolean storage class.
@@ -339,10 +403,7 @@ func (s *Stmt) ColumnType(col int) Datatype {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBool(col int) bool {
if i := s.ColumnInt64(col); i != 0 {
return true
}
return false
return s.ColumnInt64(col) != 0
}
// ColumnInt returns the value of the result column as an int.
@@ -358,7 +419,7 @@ func (s *Stmt) ColumnInt(col int) int {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
r := s.c.call(s.c.api.columnInteger,
r := s.c.call("sqlite3_column_int64",
uint64(s.handle), uint64(col))
return int64(r)
}
@@ -368,7 +429,7 @@ func (s *Stmt) ColumnInt64(col int) int64 {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
r := s.c.call(s.c.api.columnFloat,
r := s.c.call("sqlite3_column_double",
uint64(s.handle), uint64(col))
return math.Float64frombits(r)
}
@@ -422,7 +483,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
r := s.c.call("sqlite3_column_text",
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
}
@@ -434,19 +495,19 @@ func (s *Stmt) ColumnRawText(col int) []byte {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
r := s.c.call("sqlite3_column_blob",
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
}
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
if ptr == 0 {
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
r := s.c.call("sqlite3_errcode", uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r := s.c.call(s.c.api.columnBytes,
r := s.c.call("sqlite3_column_bytes",
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
}
@@ -475,17 +536,53 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error {
return json.Unmarshal(data, ptr)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
// ColumnValue returns the unprotected value of the result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnValue(col int) Value {
r := s.c.call("sqlite3_column_value",
uint64(s.handle), uint64(col))
return Value{
c: s.c,
unprot: true,
handle: uint32(r),
}
}
func (s *Stmt) Columns(dest []any) error {
defer s.c.arena.mark()()
count := uint64(len(dest))
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(8 * count)
r := s.c.call("sqlite3_columns_go",
uint64(s.handle), count, uint64(typePtr), uint64(dataPtr))
if err := s.c.error(r); err != nil {
return err
}
types := util.View(s.c.mod, typePtr, count)
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr+8*uint32(i)))
continue
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr+8*uint32(i))
continue
case byte(NULL):
dest[i] = nil
continue
}
ptr := util.ReadUint32(s.c.mod, dataPtr+8*uint32(i)+0)
len := util.ReadUint32(s.c.mod, dataPtr+8*uint32(i)+4)
buf := util.View(s.c.mod, ptr, uint64(len))
if types[i] == byte(TEXT) {
dest[i] = string(buf)
} else {
dest[i] = buf
}
}
return true
return nil
}

View File

@@ -1,58 +0,0 @@
package sqlite3
import "testing"
func Test_emptyStatement(t *testing.T) {
t.Parallel()
tests := []struct {
name string
stmt string
want bool
}{
{"empty", "", true},
{"space", " ", true},
{"separator", ";\n ", true},
{"begin", "BEGIN", false},
{"select", "SELECT 1;", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := emptyStatement(tt.stmt); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func Fuzz_emptyStatement(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add(";\n ")
f.Add("; ;\v")
f.Add("BEGIN")
f.Add("SELECT 1;")
db, err := Open(":memory:")
if err != nil {
f.Fatal(err)
}
defer db.Close()
f.Fuzz(func(t *testing.T, sql string) {
// If empty, SQLite parses it as empty.
if emptyStatement(sql) {
stmt, tail, err := db.Prepare(sql)
if err != nil {
t.Errorf("%q, %v", sql, err)
}
if stmt != nil {
t.Errorf("%q, %v", sql, stmt)
}
if tail != "" {
t.Errorf("%q", sql)
}
stmt.Close()
}
})
}

View File

@@ -289,3 +289,78 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Error("got message:", got)
}
}
func TestConn_Config(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
o, err := db.Config(sqlite3.DBCONFIG_DEFENSIVE)
if err != nil {
t.Fatal(err)
}
if o != false {
t.Error("want false")
}
o, err = db.Config(sqlite3.DBCONFIG_DEFENSIVE, true)
if err != nil {
t.Fatal(err)
}
if o != true {
t.Error("want true")
}
o, err = db.Config(sqlite3.DBCONFIG_DEFENSIVE)
if err != nil {
t.Fatal(err)
}
if o != true {
t.Error("want true")
}
o, err = db.Config(sqlite3.DBCONFIG_DEFENSIVE, false)
if err != nil {
t.Fatal(err)
}
if o != false {
t.Error("want false")
}
o, err = db.Config(sqlite3.DBCONFIG_DEFENSIVE)
if err != nil {
t.Fatal(err)
}
if o != false {
t.Error("want false")
}
}
func TestConn_ConfigLog(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var code sqlite3.ExtendedErrorCode
err = db.ConfigLog(func(c sqlite3.ExtendedErrorCode, msg string) {
t.Log(msg)
code = c
})
if err != nil {
t.Fatal(err)
}
db.Prepare(`SELECT * FRM sqlite_schema`)
if code != sqlite3.ExtendedErrorCode(sqlite3.ERROR) {
t.Error("want sqlite3.ERROR")
}
}

View File

@@ -15,6 +15,9 @@ import (
//go:embed testdata/wal.db
var waldb []byte
//go:embed testdata/utf16be.db
var utf16db []byte
func TestDB_memory(t *testing.T) {
t.Parallel()
testDB(t, ":memory:")
@@ -34,12 +37,22 @@ func TestDB_nolock(t *testing.T) {
func TestDB_wal(t *testing.T) {
t.Parallel()
wal := filepath.Join(t.TempDir(), "test.db")
err := os.WriteFile(wal, waldb, 0666)
tmp := filepath.Join(t.TempDir(), "test.db")
err := os.WriteFile(tmp, waldb, 0666)
if err != nil {
t.Fatal(err)
}
testDB(t, wal)
testDB(t, tmp)
}
func TestDB_utf16(t *testing.T) {
t.Parallel()
tmp := filepath.Join(t.TempDir(), "test.db")
err := os.WriteFile(tmp, utf16db, 0666)
if err != nil {
t.Fatal(err)
}
testDB(t, tmp)
}
func TestDB_vfs(t *testing.T) {
@@ -80,6 +93,9 @@ func testDB(t testing.TB, name string) {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if row >= 3 {
continue
}
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}

View File

@@ -72,6 +72,17 @@ func TestDriver(t *testing.T) {
}
defer rows.Close()
typs, err := rows.ColumnTypes()
if err != nil {
t.Fatal(err)
}
if got := typs[0].DatabaseTypeName(); got != "INT" {
t.Errorf("got %s, want INT", got)
}
if got := typs[1].DatabaseTypeName(); got != "VARCHAR" {
t.Errorf("got %s, want INT", got)
}
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}

View File

@@ -167,6 +167,26 @@ func TestCreateFunction(t *testing.T) {
}
}
func TestOverloadFunction(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.OverloadFunction("test", 0)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT test()`)
if err == nil {
t.Fatal("want error")
}
}
func TestAnyCollationNeeded(t *testing.T) {
t.Parallel()

View File

@@ -60,7 +60,7 @@ func TestMultiProcess(t *testing.T) {
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
cmd := exec.Command(os.Args[0], append(os.Args[1:], "-test.v", "-test.run=TestChildProcess")...)
out, err := cmd.StdoutPipe()
if err != nil {
t.Fatal(err)
@@ -71,8 +71,10 @@ func TestMultiProcess(t *testing.T) {
var buf [3]byte
// Wait for child to start.
if _, err := io.ReadFull(out, buf[:]); err != nil || string(buf[:]) != "===" {
if _, err := io.ReadFull(out, buf[:]); err != nil {
t.Fatal(err)
} else if str := string(buf[:]); str != "===" {
t.Fatal(str)
}
testParallel(t, name, 1000)

View File

@@ -30,6 +30,10 @@ func TestStmt(t *testing.T) {
}
defer stmt.Close()
if got := stmt.ReadOnly(); got != false {
t.Error("got true, want false")
}
if got := stmt.BindCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
@@ -137,6 +141,10 @@ func TestStmt(t *testing.T) {
}
defer stmt.Close()
if got := stmt.ReadOnly(); got != true {
t.Error("got false, want true")
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
@@ -586,6 +594,10 @@ func TestStmt_ColumnTime(t *testing.T) {
t.Errorf("want error")
}
}
if got := stmt.Status(sqlite3.STMTSTATUS_RUN, true); got != 1 {
t.Errorf("got %d, want 1", got)
}
}
func TestStmt_Error(t *testing.T) {

Some files were not shown because too many files have changed in this diff Show More