Compare commits

...

69 Commits

Author SHA1 Message Date
Nuno Cruces
d78239bfbf More EINTR. 2025-03-14 00:07:09 +00:00
Nuno Cruces
49852732b2 Optimize. 2025-03-12 17:29:12 +00:00
Nuno Cruces
9b90d076cb Update README.md 2025-03-12 12:01:13 +00:00
Nuno Cruces
15b94577b1 Tweak. 2025-03-11 20:15:53 +00:00
Nuno Cruces
25557244cc Global ConfigLog. 2025-03-11 17:07:56 +00:00
Nuno Cruces
c2d3bf0cfc Reduce flakyness. 2025-03-11 14:57:48 +00:00
Nuno Cruces
58a5682084 Handle EINTR. 2025-03-11 12:07:14 +00:00
Nuno Cruces
1ed954e96f Fix #243. 2025-03-10 14:54:34 +00:00
Nuno Cruces
9e7a0a875d Improved arg reuse. 2025-03-10 12:01:15 +00:00
Nuno Cruces
26adda4529 Seq aggregate functions (#229) 2025-03-08 14:07:43 +00:00
Nuno Cruces
2f6cd8de1d Docs. 2025-03-07 11:47:02 +00:00
dependabot[bot]
e027e055ff Bump golang.org/x/crypto from 0.35.0 to 0.36.0 (#239)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.35.0 to 0.36.0.
- [Commits](https://github.com/golang/crypto/compare/v0.35.0...v0.36.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-05 23:11:20 +00:00
dependabot[bot]
63fdc141e5 Bump golang.org/x/text from 0.22.0 to 0.23.0 (#240)
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.22.0 to 0.23.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.22.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-05 22:57:38 +00:00
Nuno Cruces
0bbd145a49 Update modules. 2025-02-28 16:57:25 +00:00
Nuno Cruces
c755ef96e6 Export logging. 2025-02-28 14:50:22 +00:00
Nuno Cruces
9a69e407cc Fix #235. 2025-02-28 00:33:45 +00:00
Nuno Cruces
e9db0d8e84 Issue #233. 2025-02-27 00:07:49 +00:00
dependabot[bot]
dadf53e175 Bump golang.org/x/crypto from 0.33.0 to 0.35.0 (#231)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.33.0 to 0.35.0.
- [Commits](https://github.com/golang/crypto/compare/v0.33.0...v0.35.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-24 23:15:52 +00:00
Nuno Cruces
f536765206 Go 1.23. 2025-02-24 14:09:55 +00:00
Nuno Cruces
12034c4f0b Retract. 2025-02-24 14:09:55 +00:00
Nuno Cruces
b4e5d1a213 Issue #230. 2025-02-24 13:13:25 +00:00
Nuno Cruces
b06c7dda6c Checksum robustness. 2025-02-24 13:13:25 +00:00
Nuno Cruces
5e1909a20e Issue #230. 2025-02-24 13:13:25 +00:00
Nuno Cruces
77d74baca5 Fix potential leak. 2025-02-22 12:48:41 +00:00
Nuno Cruces
4142680d5a Updated modules. 2025-02-20 13:36:02 +00:00
Nuno Cruces
9f4fe6f27c SQLite 3.49.1. 2025-02-18 18:03:20 +00:00
Nuno Cruces
7870ce0690 wazero v1.9.0. 2025-02-18 16:36:22 +00:00
Nuno Cruces
ec3226e16e Fix CI. 2025-02-17 12:21:53 +00:00
Nuno Cruces
4dd7bd0ff2 More type safe. 2025-02-17 12:00:55 +00:00
Nuno Cruces
975feb2fd4 Issue #228. 2025-02-16 18:09:42 +00:00
Nuno Cruces
58f8c2d33e Ignore. 2025-02-15 01:12:38 +00:00
Nuno Cruces
019660eed6 Fix warning. 2025-02-12 09:58:58 +00:00
Nuno Cruces
30c1bcdbe9 Serdes robustness. 2025-02-12 00:41:16 +00:00
Nuno Cruces
9b4002f5ac Add missing consts. 2025-02-11 18:24:05 +00:00
Nuno Cruces
2a78d4bc2b Updated modules. 2025-02-11 18:15:14 +00:00
Nuno Cruces
c09623a903 binaryen-version_122. 2025-02-11 18:07:30 +00:00
Nuno Cruces
fa613f9ddb Remove go.work. 2025-02-11 17:50:37 +00:00
Nuno Cruces
57997201ee SQLite 3.49.0. 2025-02-10 07:20:01 +00:00
dependabot[bot]
6995cca5c0 Bump golang.org/x/crypto from 0.32.0 to 0.33.0 (#225)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.32.0 to 0.33.0.
- [Commits](https://github.com/golang/crypto/compare/v0.32.0...v0.33.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-08 08:40:31 +00:00
dependabot[bot]
a10eef3ac8 Bump golang.org/x/sys from 0.29.0 to 0.30.0 (#223)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.29.0 to 0.30.0.
- [Commits](https://github.com/golang/sys/compare/v0.29.0...v0.30.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-05 19:02:19 +01:00
dependabot[bot]
d627ca3dc1 Bump golang.org/x/text from 0.21.0 to 0.22.0 (#221)
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.21.0 to 0.22.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.21.0...v0.22.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-04 23:35:35 +01:00
Nuno Cruces
b2f7ab8335 Fix GlobPrefix. (#220) 2025-01-28 17:54:17 +00:00
Nuno Cruces
c9135b9823 UUID version and timestamp. 2025-01-28 11:51:27 +00:00
dependabot[bot]
0d9ed94aad Bump github.com/ncruces/sort from 0.1.2 to 0.1.3 (#218)
Bumps [github.com/ncruces/sort](https://github.com/ncruces/sort) from 0.1.2 to 0.1.3.
- [Commits](https://github.com/ncruces/sort/compare/v0.1.2...v0.1.3)

---
updated-dependencies:
- dependency-name: github.com/ncruces/sort
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-24 22:56:31 +00:00
Nuno Cruces
1d951ecd18 Go 1.22. 2025-01-24 10:46:05 +00:00
Nuno Cruces
c0298ad274 NetBSD 10.1. 2025-01-22 17:20:57 +00:00
Nuno Cruces
42bad5891a Skewness and excess kurtosis. 2025-01-22 12:09:20 +00:00
Nuno Cruces
40090d8250 Moments. 2025-01-21 14:11:47 +00:00
Nuno Cruces
d2f162972d More type safe. (#216) 2025-01-21 01:42:57 +00:00
Nuno Cruces
e2da469834 Fix numerical issues. 2025-01-20 14:39:36 +00:00
Nuno Cruces
1677b97fa4 Fix #215. 2025-01-19 01:30:04 +00:00
Nuno Cruces
407e13d238 Handle some errors. 2025-01-17 14:40:12 +00:00
Nuno Cruces
9132f74b69 Use Linux ARM runners. 2025-01-17 11:49:35 +00:00
Nuno Cruces
c024121fd2 C tweaks. 2025-01-17 10:51:25 +00:00
Nuno Cruces
aa8287f8e7 Allow others to enable threads. 2025-01-16 17:21:36 +00:00
Nuno Cruces
ab09da7136 More unicode. 2025-01-16 15:46:49 +00:00
Nuno Cruces
a159b548ed Dependencies. 2025-01-14 17:53:40 +00:00
Nuno Cruces
d9b37307e7 SQLite 3.48.0. 2025-01-14 17:33:53 +00:00
Nuno Cruces
3bae1d7d4b SQLITE_FCNTL_BUSYHANDLER. 2025-01-14 17:09:54 +00:00
Nuno Cruces
8887036c20 SQLITE_FCNTL_SYNC. 2025-01-14 10:05:54 +00:00
Nuno Cruces
ccb3dcd097 SQLITE_FCNTL_PDB. 2025-01-13 13:45:41 +00:00
Nuno Cruces
a9f33cc2b0 New constants. 2025-01-13 12:05:27 +00:00
Nuno Cruces
f025ffb385 Fix naming. 2025-01-13 09:28:47 +00:00
Nuno Cruces
aa4357a78f Ordered-set aggregate syntax. 2025-01-11 19:22:04 +00:00
Nuno Cruces
aef7f051a8 Prevent modification. 2025-01-10 12:38:11 +00:00
Nuno Cruces
a79ee4c2c6 Avoid weird mutex. 2025-01-09 13:44:29 +00:00
Nuno Cruces
7424747338 Update README.md 2025-01-08 23:16:25 +00:00
Nuno Cruces
11830e05a6 Remove legacy. 2025-01-08 18:34:48 +00:00
Nuno Cruces
7dc4520690 Fix #207. 2025-01-08 16:36:41 +00:00
130 changed files with 2973 additions and 1670 deletions

View File

@@ -1,11 +1,6 @@
name: VM Actions matrix
description: VM Actions matrix template
inputs:
run:
description: The CI command to run
required: true
runs:
using: composite
steps:
@@ -13,4 +8,4 @@ runs:
with:
usesh: true
copyback: false
run: ${{inputs.run}}
run: . ./test.sh

View File

@@ -3,13 +3,13 @@ set -euo pipefail
if [[ "$OSTYPE" == "linux"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-linux.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_121/binaryen-version_121-x86_64-linux.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_122/binaryen-version_122-x86_64-linux.tar.gz"
elif [[ "$OSTYPE" == "darwin"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-arm64-macos.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_121/binaryen-version_121-arm64-macos.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_122/binaryen-version_122-arm64-macos.tar.gz"
elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-windows.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_121/binaryen-version_121-x86_64-windows.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_122/binaryen-version_122-x86_64-windows.tar.gz"
fi
# Download tools

View File

@@ -7,12 +7,14 @@ on:
- '**.go'
- '**.mod'
- '**.wasm'
- '**.yml'
pull_request:
branches: [ 'main' ]
paths:
- '**.go'
- '**.mod'
- '**.wasm'
- '**.yml'
workflow_dispatch:
jobs:
@@ -57,16 +59,22 @@ jobs:
run: go test -v -tags sqlite3_dotlk ./...
if: matrix.os != 'windows-latest'
- name: Test modules
shell: bash
run: |
go work init .
go work use -r embed gormlite
go test -v ./embed/bcw2/...
- name: Test GORM
shell: bash
run: gormlite/test.sh
- name: Test modules
shell: bash
run: go test -v ./embed/bcw2/...
if: matrix.os != 'windows-latest'
- name: Collect coverage
run: go run github.com/dave/courtney@latest
run: |
go get -tool github.com/dave/courtney@v0.4.4
go tool courtney
if: |
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'
@@ -88,7 +96,7 @@ jobs:
version: '14.2'
flags: '-test.v'
- name: netbsd
version: '10.0'
version: '10.1'
flags: '-test.v'
- name: freebsd
arch: arm64
@@ -96,7 +104,7 @@ jobs:
flags: '-test.v -test.short'
- name: netbsd
arch: arm64
version: '10.0'
version: '10.1'
flags: '-test.v -test.short'
- name: openbsd
version: '7.6'
@@ -115,7 +123,7 @@ jobs:
run: .github/workflows/build-test.sh
- name: Test
uses: cross-platform-actions/action@v0.26.0
uses: cross-platform-actions/action@v0.27.0
with:
operating_system: ${{ matrix.os.name }}
architecture: ${{ matrix.os.arch }}
@@ -154,10 +162,6 @@ jobs:
- name: Test
uses: ./.github/actions/vmactions
with:
usesh: true
copyback: false
run: . ./test.sh
test-wasip1:
runs-on: ubuntu-latest
@@ -170,7 +174,7 @@ jobs:
with: { go-version: stable }
- name: Set path
run: echo "$(go env GOROOT)/misc/wasm" >> "$GITHUB_PATH"
run: echo "$(go env GOROOT)/lib/wasm" >> "$GITHUB_PATH"
- name: Test wasmtime
env:
@@ -193,9 +197,6 @@ jobs:
- name: Test 386 (32-bit)
run: GOARCH=386 go test -v -short ./...
- name: Test arm64 (compiler)
run: GOARCH=arm64 go test -v -short ./...
- name: Test riscv64 (interpreter)
run: GOARCH=riscv64 go test -v -short ./...
@@ -205,6 +206,18 @@ jobs:
- name: Test s390x (big-endian)
run: GOARCH=s390x go test -v -short -tags sqlite3_dotlk ./...
test-linuxarm:
runs-on: ubuntu-24.04-arm
needs: test
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with: { go-version: stable }
- name: Test
run: go test -v ./...
test-macintel:
runs-on: macos-13
needs: test

9
.gitignore vendored
View File

@@ -13,4 +13,11 @@
# Dependency directories (remove the comment below to include it)
# vendor/
tools
tools
# Go workspace file
go.work
go.work.sum
# env file
.env

View File

@@ -65,18 +65,21 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
This module replaces the SQLite [OS Interface](https://sqlite.org/vfs.html)
(aka VFS) with a [pure Go](vfs/) implementation,
which has advantages and disadvantages.
Read more about the Go VFS design [here](vfs/README.md).
Because each database connection executes within a Wasm sandboxed environment,
memory usage will be higher than alternatives.
### Testing
This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report).
It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing.
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach)
thorough testing.
Every commit is [tested](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) on
Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (amd64/arm64),
Windows (amd64), FreeBSD (amd64), OpenBSD (amd64), NetBSD (amd64),
Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (arm64/amd64),
Windows (amd64), FreeBSD (amd64/arm64), OpenBSD (amd64), NetBSD (amd64/arm64),
DragonFly BSD (amd64), illumos (amd64), and Solaris (amd64).
The Go VFS is tested by running SQLite's
@@ -84,12 +87,21 @@ The Go VFS is tested by running SQLite's
### Performance
Perfomance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
Performance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
[competitive](https://github.com/cvilsmeier/go-sqlite-bench) with alternatives.
The Wasm and VFS layers are also tested by running SQLite's
The Wasm and VFS layers are also benchmarked by running SQLite's
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Concurrency
This module behaves similarly to SQLite in [multi-thread](https://sqlite.org/threadsafe.html) mode:
it is goroutine-safe, provided that no single database connection, or object derived from it,
is used concurrently by multiple goroutines.
The [`database/sql`](https://pkg.go.dev/database/sql) API is safe to use concurrently,
according to its documentation.
### FAQ, issues, new features
For questions, please see [Discussions](https://github.com/ncruces/go-sqlite3/discussions/categories/q-a).
@@ -98,7 +110,7 @@ Also, post there if you used this driver for something interesting
([_"Show and tell"_](https://github.com/ncruces/go-sqlite3/discussions/categories/show-and-tell)),
have an [idea](https://github.com/ncruces/go-sqlite3/discussions/categories/ideas)…
The [Issue](https://github.com/ncruces/go-sqlite3/issues) tracker is for bugs we want fixed,
The [Issue](https://github.com/ncruces/go-sqlite3/issues) tracker is for bugs,
and features we're working on, planning to work on, or asking for help with.
### Alternatives
@@ -106,4 +118,4 @@ and features we're working on, planning to work on, or asking for help with.
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

View File

@@ -5,8 +5,8 @@ package sqlite3
// https://sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle uint32
otherc uint32
handle ptr_t
otherc ptr_t
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
@@ -61,7 +61,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) (*Backup, error) {
defer c.arena.mark()()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
@@ -71,19 +71,19 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
other = src
}
r := c.call("sqlite3_backup_init",
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r == 0 {
ptr := ptr_t(c.call("sqlite3_backup_init",
stk_t(dst), stk_t(dstPtr),
stk_t(src), stk_t(srcPtr)))
if ptr == 0 {
defer c.closeDB(other)
r = c.call("sqlite3_errcode", uint64(dst))
return nil, c.sqlite.error(r, dst)
rc := res_t(c.call("sqlite3_errcode", stk_t(dst)))
return nil, c.sqlite.error(rc, dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r),
handle: ptr,
}, nil
}
@@ -97,10 +97,10 @@ func (b *Backup) Close() error {
return nil
}
r := b.c.call("sqlite3_backup_finish", uint64(b.handle))
rc := res_t(b.c.call("sqlite3_backup_finish", stk_t(b.handle)))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r)
return b.c.error(rc)
}
// Step copies up to nPage pages between the source and destination databases.
@@ -108,11 +108,11 @@ func (b *Backup) Close() error {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage))
if r == _DONE {
rc := res_t(b.c.call("sqlite3_backup_step", stk_t(b.handle), stk_t(nPage)))
if rc == _DONE {
return true, nil
}
return false, b.c.error(r)
return false, b.c.error(rc)
}
// Remaining returns the number of pages still to be backed up
@@ -120,8 +120,8 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call("sqlite3_backup_remaining", uint64(b.handle))
return int(int32(r))
n := int32(b.c.call("sqlite3_backup_remaining", stk_t(b.handle)))
return int(n)
}
// PageCount returns the total number of pages in the source database
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call("sqlite3_backup_pagecount", uint64(b.handle))
return int(int32(r))
n := int32(b.c.call("sqlite3_backup_pagecount", stk_t(b.handle)))
return int(n)
}

64
blob.go
View File

@@ -20,8 +20,8 @@ type Blob struct {
c *Conn
bytes int64
offset int64
handle uint32
bufptr uint32
handle ptr_t
bufptr ptr_t
buflen int64
}
@@ -37,23 +37,23 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
tablePtr := c.arena.string(table)
columnPtr := c.arena.string(column)
var flags uint64
var flags int32
if write {
flags = 1
}
c.checkInterrupt(c.handle)
r := c.call("sqlite3_blob_open", uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
if err := c.error(r); err != nil {
if err := c.error(rc); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call("sqlite3_blob_bytes", uint64(blob.handle)))
blob.handle = util.Read32[ptr_t](c.mod, blobPtr)
blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", stk_t(blob.handle))))
return &blob, nil
}
@@ -67,10 +67,10 @@ func (b *Blob) Close() error {
return nil
}
r := b.c.call("sqlite3_blob_close", uint64(b.handle))
rc := res_t(b.c.call("sqlite3_blob_close", stk_t(b.handle)))
b.c.free(b.bufptr)
b.handle = 0
return b.c.error(r)
return b.c.error(rc)
}
// Size returns the size of the BLOB in bytes.
@@ -94,13 +94,13 @@ func (b *Blob) Read(p []byte) (n int, err error) {
want = avail
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
uint64(b.bufptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
if err != nil {
return 0, err
}
@@ -109,7 +109,7 @@ func (b *Blob) Read(p []byte) (n int, err error) {
err = io.EOF
}
copy(p, util.View(b.c.mod, b.bufptr, uint64(want)))
copy(p, util.View(b.c.mod, b.bufptr, want))
return int(want), err
}
@@ -127,19 +127,19 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
want = avail
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
for want > 0 {
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
uint64(b.bufptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
if err != nil {
return n, err
}
mem := util.View(b.c.mod, b.bufptr, uint64(want))
mem := util.View(b.c.mod, b.bufptr, want)
m, err := w.Write(mem[:want])
b.offset += int64(m)
n += int64(m)
@@ -165,14 +165,14 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
func (b *Blob) Write(p []byte) (n int, err error) {
want := int64(len(p))
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
util.WriteBytes(b.c.mod, b.bufptr, p)
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
uint64(b.bufptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
if err != nil {
return 0, err
}
@@ -196,17 +196,17 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want = 1
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
for {
mem := util.View(b.c.mod, b.bufptr, uint64(want))
mem := util.View(b.c.mod, b.bufptr, want)
m, err := r.Read(mem[:want])
if m > 0 {
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
uint64(b.bufptr), uint64(m), uint64(b.offset))
err := b.c.error(r)
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
stk_t(b.bufptr), stk_t(m), stk_t(b.offset)))
err := b.c.error(rc)
if err != nil {
return n, err
}
@@ -254,8 +254,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
b.c.checkInterrupt(b.c.handle)
err := b.c.error(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row)))
b.bytes = int64(b.c.call("sqlite3_blob_bytes", uint64(b.handle)))
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.offset = 0
return err
}

168
config.go
View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strconv"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
@@ -32,7 +33,7 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
defer c.arena.mark()()
argsPtr := c.arena.new(intlen + ptrlen)
var flag int
var flag int32
switch {
case len(arg) == 0:
flag = -1
@@ -40,31 +41,40 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
flag = 1
}
util.WriteUint32(c.mod, argsPtr+0*ptrlen, uint32(flag))
util.WriteUint32(c.mod, argsPtr+1*ptrlen, argsPtr)
util.Write32(c.mod, argsPtr+0*ptrlen, flag)
util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr)
r := c.call("sqlite3_db_config", uint64(c.handle),
uint64(op), uint64(argsPtr))
return util.ReadUint32(c.mod, argsPtr) != 0, c.error(r)
rc := res_t(c.call("sqlite3_db_config", stk_t(c.handle),
stk_t(op), stk_t(argsPtr)))
return util.ReadBool(c.mod, argsPtr), c.error(rc)
}
var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)]
// ConfigLog sets up the default error logging callback for new connections.
//
// https://sqlite.org/errlog.html
func ConfigLog(cb func(code ExtendedErrorCode, msg string)) {
defaultLogger.Store(&cb)
}
// ConfigLog sets up the error logging callback for the connection.
//
// https://sqlite.org/errlog.html
func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
r := c.call("sqlite3_config_log_go", enable)
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_config_log_go", stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.log = cb
return nil
}
func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg uint32) {
func logCallback(ctx context.Context, mod api.Module, _ ptr_t, iCode res_t, zMsg ptr_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil {
msg := util.ReadString(mod, zMsg, _MAX_LENGTH)
c.log(xErrorCode(iCode), msg)
@@ -88,93 +98,93 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro
defer c.arena.mark()()
ptr := c.arena.new(max(ptrlen, intlen))
var schemaPtr uint32
var schemaPtr ptr_t
if schema != "" {
schemaPtr = c.arena.string(schema)
}
var rc uint64
var res any
var rc res_t
var ret any
switch op {
default:
return nil, MISUSE
case FCNTL_RESET_CACHE:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), 0)
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), 0))
case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE:
var flag int
var flag int32
switch {
case len(arg) == 0:
flag = -1
case arg[0]:
flag = 1
}
util.WriteUint32(c.mod, ptr, uint32(flag))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = util.ReadUint32(c.mod, ptr) != 0
util.Write32(c.mod, ptr, flag)
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.ReadBool(c.mod, ptr)
case FCNTL_CHUNK_SIZE:
util.WriteUint32(c.mod, ptr, uint32(arg[0].(int)))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
util.Write32(c.mod, ptr, int32(arg[0].(int)))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
case FCNTL_RESERVE_BYTES:
bytes := -1
if len(arg) > 0 {
bytes = arg[0].(int)
}
util.WriteUint32(c.mod, ptr, uint32(bytes))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = int(util.ReadUint32(c.mod, ptr))
util.Write32(c.mod, ptr, int32(bytes))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = int(util.Read32[int32](c.mod, ptr))
case FCNTL_DATA_VERSION:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = util.ReadUint32(c.mod, ptr)
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.Read32[uint32](c.mod, ptr)
case FCNTL_LOCKSTATE:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = vfs.LockLevel(util.ReadUint32(c.mod, ptr))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.Read32[vfs.LockLevel](c.mod, ptr)
case FCNTL_VFS_POINTER:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
if rc == _OK {
const zNameOffset = 16
ptr = util.ReadUint32(c.mod, ptr)
ptr = util.ReadUint32(c.mod, ptr+zNameOffset)
ptr = util.Read32[ptr_t](c.mod, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset)
name := util.ReadString(c.mod, ptr, _MAX_NAME)
res = vfs.Find(name)
ret = vfs.Find(name)
}
case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
if rc == _OK {
const fileHandleOffset = 4
ptr = util.ReadUint32(c.mod, ptr)
ptr = util.ReadUint32(c.mod, ptr+fileHandleOffset)
res = util.GetHandle(c.ctx, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset)
ret = util.GetHandle(c.ctx, ptr)
}
}
if err := c.error(rc); err != nil {
return nil, err
}
return res, nil
return ret, nil
}
// Limit allows the size of various constructs to be
@@ -182,20 +192,20 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro
//
// https://sqlite.org/c3ref/limit.html
func (c *Conn) Limit(id LimitCategory, value int) int {
r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value))
return int(int32(r))
v := int32(c.call("sqlite3_limit", stk_t(c.handle), stk_t(id), stk_t(value)))
return int(v)
}
// SetAuthorizer registers an authorizer callback with the database connection.
//
// https://sqlite.org/c3ref/set_authorizer.html
func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_set_authorizer_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.authorizer = cb
@@ -203,7 +213,7 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4
}
func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner uint32) (rc AuthorizerReturnCode) {
func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner ptr_t) (rc AuthorizerReturnCode) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil {
var name3rd, name4th, schema, inner string
if zName3rd != 0 {
@@ -227,15 +237,15 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action
//
// https://sqlite.org/c3ref/trace_v2.html
func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error {
r := c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask))
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_trace_go", stk_t(c.handle), stk_t(mask)))
if err := c.error(rc); err != nil {
return err
}
c.trace = cb
return nil
}
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 uint32) (rc uint32) {
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc res_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil {
var arg1, arg2 any
if evt == TRACE_CLOSE {
@@ -248,7 +258,7 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
case TRACE_STMT:
arg2 = s.SQL()
case TRACE_PROFILE:
arg2 = int64(util.ReadUint64(mod, pArg2))
arg2 = util.Read64[int64](mod, pArg2)
}
break
}
@@ -269,20 +279,20 @@ func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt in
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
schemaPtr := c.arena.string(schema)
r := c.call("sqlite3_wal_checkpoint_v2",
uint64(c.handle), uint64(schemaPtr), uint64(mode),
uint64(nLogPtr), uint64(nCkptPtr))
nLog = int(int32(util.ReadUint32(c.mod, nLogPtr)))
nCkpt = int(int32(util.ReadUint32(c.mod, nCkptPtr)))
return nLog, nCkpt, c.error(r)
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
stk_t(nLogPtr), stk_t(nCkptPtr)))
nLog = int(util.Read32[int32](c.mod, nLogPtr))
nCkpt = int(util.Read32[int32](c.mod, nCkptPtr))
return nLog, nCkpt, c.error(rc)
}
// WALAutoCheckpoint configures WAL auto-checkpoints.
//
// https://sqlite.org/c3ref/wal_autocheckpoint.html
func (c *Conn) WALAutoCheckpoint(pages int) error {
r := c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages))
return c.error(r)
rc := res_t(c.call("sqlite3_wal_autocheckpoint", stk_t(c.handle), stk_t(pages)))
return c.error(rc)
}
// WALHook registers a callback function to be invoked
@@ -290,15 +300,15 @@ func (c *Conn) WALAutoCheckpoint(pages int) error {
//
// https://sqlite.org/c3ref/wal_hook.html
func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_wal_hook_go", uint64(c.handle), enable)
c.call("sqlite3_wal_hook_go", stk_t(c.handle), stk_t(enable))
c.wal = cb
}
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pages int32) (rc uint32) {
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc res_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil {
schema := util.ReadString(mod, zSchema, _MAX_NAME)
err := c.wal(c, schema, int(pages))
@@ -311,15 +321,15 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa
//
// https://sqlite.org/c3ref/autovacuum_pages.html
func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error {
var funcPtr uint32
var funcPtr ptr_t
if cb != nil {
funcPtr = util.AddHandle(c.ctx, cb)
}
r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))
return c.error(r)
rc := res_t(c.call("sqlite3_autovacuum_pages_go", stk_t(c.handle), stk_t(funcPtr)))
return c.error(rc)
}
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(schema string, dbPages, freePages, bytesPerPage uint) uint)
schema := util.ReadString(mod, zSchema, _MAX_NAME)
return uint32(fn(schema, uint(nDbPage), uint(nFreePage), uint(nBytePerPage)))
@@ -329,14 +339,14 @@ func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbP
//
// https://sqlite.org/c3ref/hard_heap_limit64.html
func (c *Conn) SoftHeapLimit(n int64) int64 {
return int64(c.call("sqlite3_soft_heap_limit64", uint64(n)))
return int64(c.call("sqlite3_soft_heap_limit64", stk_t(n)))
}
// HardHeapLimit imposes a hard limit on heap size.
//
// https://sqlite.org/c3ref/hard_heap_limit64.html
func (c *Conn) HardHeapLimit(n int64) int64 {
return int64(c.call("sqlite3_hard_heap_limit64", uint64(n)))
return int64(c.call("sqlite3_hard_heap_limit64", stk_t(n)))
}
// EnableChecksums enables checksums on a database.

169
conn.go
View File

@@ -3,6 +3,7 @@ package sqlite3
import (
"context"
"fmt"
"iter"
"math"
"math/rand"
"net/url"
@@ -35,11 +36,11 @@ type Conn struct {
update func(AuthorizerActionCode, string, string, int64)
commit func() bool
rollback func()
arena arena
busy1st time.Time
busylst time.Time
handle uint32
arena arena
handle ptr_t
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
@@ -48,7 +49,7 @@ func Open(filename string) (*Conn, error) {
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connectiton.
// which is used to interrupt the process of opening the connection.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
@@ -68,9 +69,9 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(context.Background(), filename, flags)
}
type connKey struct{}
type connKey = util.ConnKey
func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) {
func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ error) {
err := ctx.Err()
if err != nil {
return nil, err
@@ -82,7 +83,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _
return nil, err
}
defer func() {
if res == nil {
if ret == nil {
c.Close()
c.sqlite.close()
} else {
@@ -91,7 +92,10 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _
}()
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.arena = c.newArena(1024)
if logger := defaultLogger.Load(); logger != nil {
c.ConfigLog(*logger)
}
c.arena = c.newArena()
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
@@ -102,21 +106,21 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
defer c.arena.mark()()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)
rc := res_t(c.call("sqlite3_open_v2", stk_t(namePtr), stk_t(connPtr), stk_t(flags), 0))
handle := util.ReadUint32(c.mod, connPtr)
if err := c.sqlite.error(r, handle); err != nil {
handle := util.Read32[ptr_t](c.mod, connPtr)
if err := c.sqlite.error(rc, handle); err != nil {
c.closeDB(handle)
return 0, err
}
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
c.call("sqlite3_progress_handler_go", stk_t(handle), 100)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
@@ -130,8 +134,8 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String())
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0))
if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
c.closeDB(handle)
return 0, err
@@ -141,9 +145,9 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call("sqlite3_close_v2", uint64(handle))
if err := c.sqlite.error(r, handle); err != nil {
func (c *Conn) closeDB(handle ptr_t) {
rc := res_t(c.call("sqlite3_close_v2", stk_t(handle)))
if err := c.sqlite.error(rc, handle); err != nil {
panic(err)
}
}
@@ -165,8 +169,8 @@ func (c *Conn) Close() error {
c.pending.Close()
c.pending = nil
r := c.call("sqlite3_close", uint64(c.handle))
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
return err
}
@@ -183,8 +187,8 @@ func (c *Conn) Exec(sql string) error {
sqlPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r, sql)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0))
return c.error(rc, sql)
}
// Prepare calls [Conn.PrepareFlags] with no flags.
@@ -209,17 +213,17 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
sqlPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
r := c.call("sqlite3_prepare_v3", uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
stmt = &Stmt{c: c}
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" {
stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr)
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" {
tail = sql
}
if err := c.error(r, sql); err != nil {
if err := c.error(rc, sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
@@ -233,9 +237,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
//
// https://sqlite.org/c3ref/db_name.html
func (c *Conn) DBName(n int) string {
r := c.call("sqlite3_db_name", uint64(c.handle), uint64(n))
ptr := uint32(r)
ptr := ptr_t(c.call("sqlite3_db_name", stk_t(c.handle), stk_t(n)))
if ptr == 0 {
return ""
}
@@ -246,34 +248,34 @@ func (c *Conn) DBName(n int) string {
//
// https://sqlite.org/c3ref/db_filename.html
func (c *Conn) Filename(schema string) *vfs.Filename {
var ptr uint32
var ptr ptr_t
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))
return vfs.GetFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB)
ptr = ptr_t(c.call("sqlite3_db_filename", stk_t(c.handle), stk_t(ptr)))
return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB)
}
// ReadOnly determines if a database is read-only.
//
// https://sqlite.org/c3ref/db_readonly.html
func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) {
var ptr uint32
var ptr ptr_t
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
r := c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr))
return int32(r) > 0, int32(r) < 0
b := int32(c.call("sqlite3_db_readonly", stk_t(c.handle), stk_t(ptr)))
return b > 0, b < 0
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call("sqlite3_get_autocommit", uint64(c.handle))
return r != 0
b := int32(c.call("sqlite3_get_autocommit", stk_t(c.handle)))
return b != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
@@ -281,8 +283,7 @@ func (c *Conn) GetAutocommit() bool {
//
// https://sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
r := c.call("sqlite3_last_insert_rowid", uint64(c.handle))
return int64(r)
return int64(c.call("sqlite3_last_insert_rowid", stk_t(c.handle)))
}
// SetLastInsertRowID allows the application to set the value returned by
@@ -290,7 +291,7 @@ func (c *Conn) LastInsertRowID() int64 {
//
// https://sqlite.org/c3ref/set_last_insert_rowid.html
func (c *Conn) SetLastInsertRowID(id int64) {
c.call("sqlite3_set_last_insert_rowid", uint64(c.handle), uint64(id))
c.call("sqlite3_set_last_insert_rowid", stk_t(c.handle), stk_t(id))
}
// Changes returns the number of rows modified, inserted or deleted
@@ -299,8 +300,7 @@ func (c *Conn) SetLastInsertRowID(id int64) {
//
// https://sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
r := c.call("sqlite3_changes64", uint64(c.handle))
return int64(r)
return int64(c.call("sqlite3_changes64", stk_t(c.handle)))
}
// TotalChanges returns the number of rows modified, inserted or deleted
@@ -309,16 +309,15 @@ func (c *Conn) Changes() int64 {
//
// https://sqlite.org/c3ref/total_changes.html
func (c *Conn) TotalChanges() int64 {
r := c.call("sqlite3_total_changes64", uint64(c.handle))
return int64(r)
return int64(c.call("sqlite3_total_changes64", stk_t(c.handle)))
}
// ReleaseMemory frees memory used by a database connection.
//
// https://sqlite.org/c3ref/db_release_memory.html
func (c *Conn) ReleaseMemory() error {
r := c.call("sqlite3_db_release_memory", uint64(c.handle))
return c.error(r)
rc := res_t(c.call("sqlite3_db_release_memory", stk_t(c.handle)))
return c.error(rc)
}
// GetInterrupt gets the context set with [Conn.SetInterrupt].
@@ -354,10 +353,10 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64,
uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0)
c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64,
stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0)
c.pending = &Stmt{c: c}
c.pending.handle = util.ReadUint32(c.mod, stmtPtr)
c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr)
}
if old.Done() != nil && ctx.Err() == nil {
@@ -369,13 +368,13 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
return old
}
func (c *Conn) checkInterrupt(handle uint32) {
func (c *Conn) checkInterrupt(handle ptr_t) {
if c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", uint64(handle))
c.call("sqlite3_interrupt", stk_t(handle))
}
}
func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) {
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt.Done() != nil {
runtime.Gosched()
@@ -392,11 +391,11 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt
// https://sqlite.org/c3ref/busy_timeout.html
func (c *Conn) BusyTimeout(timeout time.Duration) error {
ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32)
r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))
return c.error(r)
rc := res_t(c.call("sqlite3_busy_timeout", stk_t(c.handle), stk_t(ms)))
return c.error(rc)
}
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) {
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry int32) {
// https://fractaledmind.github.io/2024/04/15/sqlite-on-rails-the-how-and-why-of-optimal-performance/
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil {
switch {
@@ -419,19 +418,19 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r
//
// https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_busy_handler_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.busy = cb
return nil
}
func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
interrupt := c.interrupt
if interrupt == nil {
@@ -452,16 +451,16 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro
hiPtr := c.arena.new(intlen)
curPtr := c.arena.new(intlen)
var i uint64
var i int32
if reset {
i = 1
}
r := c.call("sqlite3_db_status", uint64(c.handle),
uint64(op), uint64(curPtr), uint64(hiPtr), i)
if err = c.error(r); err == nil {
current = int(util.ReadUint32(c.mod, curPtr))
highwater = int(util.ReadUint32(c.mod, hiPtr))
rc := res_t(c.call("sqlite3_db_status", stk_t(c.handle),
stk_t(op), stk_t(curPtr), stk_t(hiPtr), stk_t(i)))
if err = c.error(rc); err == nil {
current = int(util.Read32[int32](c.mod, curPtr))
highwater = int(util.Read32[int32](c.mod, hiPtr))
}
return
}
@@ -472,7 +471,7 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro
func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) {
defer c.arena.mark()()
var schemaPtr, columnPtr uint32
var schemaPtr, columnPtr ptr_t
declTypePtr := c.arena.new(ptrlen)
collSeqPtr := c.arena.new(ptrlen)
notNullPtr := c.arena.new(ptrlen)
@@ -486,32 +485,38 @@ func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, coll
columnPtr = c.arena.string(column)
}
r := c.call("sqlite3_table_column_metadata", uint64(c.handle),
uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr),
uint64(declTypePtr), uint64(collSeqPtr),
uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr))
if err = c.error(r); err == nil && column != "" {
if ptr := util.ReadUint32(c.mod, declTypePtr); ptr != 0 {
rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle),
stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(declTypePtr), stk_t(collSeqPtr),
stk_t(notNullPtr), stk_t(primaryKeyPtr), stk_t(autoIncPtr)))
if err = c.error(rc); err == nil && column != "" {
if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 {
declType = util.ReadString(c.mod, ptr, _MAX_NAME)
}
if ptr := util.ReadUint32(c.mod, collSeqPtr); ptr != 0 {
if ptr := util.Read32[ptr_t](c.mod, collSeqPtr); ptr != 0 {
collSeq = util.ReadString(c.mod, ptr, _MAX_NAME)
}
notNull = util.ReadUint32(c.mod, notNullPtr) != 0
autoInc = util.ReadUint32(c.mod, autoIncPtr) != 0
primaryKey = util.ReadUint32(c.mod, primaryKeyPtr) != 0
notNull = util.ReadBool(c.mod, notNullPtr)
autoInc = util.ReadBool(c.mod, autoIncPtr)
primaryKey = util.ReadBool(c.mod, primaryKeyPtr)
}
return
}
func (c *Conn) error(rc uint64, sql ...string) error {
func (c *Conn) error(rc res_t, sql ...string) error {
return c.sqlite.error(rc, c.handle, sql...)
}
func (c *Conn) stmtsIter(yield func(*Stmt) bool) {
for _, s := range c.stmts {
if !yield(s) {
break
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() iter.Seq[*Stmt] {
return func(yield func(*Stmt) bool) {
for _, s := range c.stmts {
if !yield(s) {
break
}
}
}
}

View File

@@ -1,11 +0,0 @@
//go:build go1.23
package sqlite3
import "iter"
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() iter.Seq[*Stmt] { return c.stmtsIter }

View File

@@ -1,9 +0,0 @@
//go:build !go1.23
package sqlite3
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() func(func(*Stmt) bool) { return c.stmtsIter }

View File

@@ -1,19 +1,28 @@
package sqlite3
import "strconv"
import (
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
)
const (
_OK = 0 /* Successful result */
_ROW = 100 /* sqlite3_step() has another row ready */
_DONE = 101 /* sqlite3_step() has finished executing */
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
_MAX_FUNCTION_ARG = 100
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
ptrlen = 4
intlen = 4
ptrlen = util.PtrLen
intlen = util.IntLen
)
type (
stk_t = util.Stk_t
ptr_t = util.Ptr_t
res_t = util.Res_t
)
// ErrorCode is a result code that [Error.Code] might return.
@@ -166,6 +175,7 @@ const (
PREPARE_PERSISTENT PrepareFlag = 0x01
PREPARE_NORMALIZE PrepareFlag = 0x02
PREPARE_NO_VTAB PrepareFlag = 0x04
PREPARE_DONT_LOG PrepareFlag = 0x10
)
// FunctionFlag is a flag that can be passed to
@@ -219,6 +229,7 @@ const (
DBSTATUS_DEFERRED_FKS DBStatus = 10
DBSTATUS_CACHE_USED_SHARED DBStatus = 11
DBSTATUS_CACHE_SPILL DBStatus = 12
// DBSTATUS_MAX DBStatus = 12
)
// DBConfig are the available database connection configuration options.
@@ -247,7 +258,10 @@ const (
DBCONFIG_TRUSTED_SCHEMA DBConfig = 1017
DBCONFIG_STMT_SCANSTATUS DBConfig = 1018
DBCONFIG_REVERSE_SCANORDER DBConfig = 1019
// DBCONFIG_MAX DBConfig = 1019
DBCONFIG_ENABLE_ATTACH_CREATE DBConfig = 1020
DBCONFIG_ENABLE_ATTACH_WRITE DBConfig = 1021
DBCONFIG_ENABLE_COMMENTS DBConfig = 1022
// DBCONFIG_MAX DBConfig = 1022
)
// FcntlOpcode are the available opcodes for [Conn.FileControl].

View File

@@ -15,7 +15,7 @@ import (
// https://sqlite.org/c3ref/context.html
type Context struct {
c *Conn
handle uint32
handle ptr_t
}
// Conn returns the database connection of the
@@ -32,14 +32,14 @@ func (ctx Context) Conn() *Conn {
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(ctx.c.ctx, data)
ctx.c.call("sqlite3_set_auxdata_go", uint64(ctx.handle), uint64(n), uint64(ptr))
ctx.c.call("sqlite3_set_auxdata_go", stk_t(ctx.handle), stk_t(n), stk_t(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) GetAuxData(n int) any {
ptr := uint32(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n)))
ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", stk_t(ctx.handle), stk_t(n)))
return util.GetHandle(ctx.c.ctx, ptr)
}
@@ -68,7 +68,7 @@ func (ctx Context) ResultInt(value int) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultInt64(value int64) {
ctx.c.call("sqlite3_result_int64",
uint64(ctx.handle), uint64(value))
stk_t(ctx.handle), stk_t(value))
}
// ResultFloat sets the result of the function to a float64.
@@ -76,7 +76,7 @@ func (ctx Context) ResultInt64(value int64) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultFloat(value float64) {
ctx.c.call("sqlite3_result_double",
uint64(ctx.handle), math.Float64bits(value))
stk_t(ctx.handle), stk_t(math.Float64bits(value)))
}
// ResultText sets the result of the function to a string.
@@ -85,7 +85,7 @@ func (ctx Context) ResultFloat(value float64) {
func (ctx Context) ResultText(value string) {
ptr := ctx.c.newString(value)
ctx.c.call("sqlite3_result_text_go",
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultRawText sets the text result of the function to a []byte.
@@ -95,7 +95,7 @@ func (ctx Context) ResultText(value string) {
func (ctx Context) ResultRawText(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_text_go",
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultBlob sets the result of the function to a []byte.
@@ -105,7 +105,7 @@ func (ctx Context) ResultRawText(value []byte) {
func (ctx Context) ResultBlob(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_blob_go",
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB.
@@ -113,7 +113,7 @@ func (ctx Context) ResultBlob(value []byte) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultZeroBlob(n int64) {
ctx.c.call("sqlite3_result_zeroblob64",
uint64(ctx.handle), uint64(n))
stk_t(ctx.handle), stk_t(n))
}
// ResultNull sets the result of the function to NULL.
@@ -121,7 +121,7 @@ func (ctx Context) ResultZeroBlob(n int64) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultNull() {
ctx.c.call("sqlite3_result_null",
uint64(ctx.handle))
stk_t(ctx.handle))
}
// ResultTime sets the result of the function to a [time.Time].
@@ -146,14 +146,14 @@ func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
}
func (ctx Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano)) + 5
const maxlen = int64(len(time.RFC3339Nano)) + 5
ptr := ctx.c.new(maxlen)
buf := util.View(ctx.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
ctx.c.call("sqlite3_result_text_go",
uint64(ctx.handle), uint64(ptr), uint64(len(buf)))
stk_t(ctx.handle), stk_t(ptr), stk_t(len(buf)))
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
@@ -164,7 +164,7 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call("sqlite3_result_pointer_go",
uint64(ctx.handle), uint64(valPtr))
stk_t(ctx.handle), stk_t(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
@@ -188,7 +188,7 @@ func (ctx Context) ResultValue(value Value) {
return
}
ctx.c.call("sqlite3_result_value",
uint64(ctx.handle), uint64(value.handle))
stk_t(ctx.handle), stk_t(value.handle))
}
// ResultError sets the result of the function an error.
@@ -196,12 +196,12 @@ func (ctx Context) ResultValue(value Value) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
ctx.c.call("sqlite3_result_error_nomem", uint64(ctx.handle))
ctx.c.call("sqlite3_result_error_nomem", stk_t(ctx.handle))
return
}
if errors.Is(err, TOOBIG) {
ctx.c.call("sqlite3_result_error_toobig", uint64(ctx.handle))
ctx.c.call("sqlite3_result_error_toobig", stk_t(ctx.handle))
return
}
@@ -210,11 +210,11 @@ func (ctx Context) ResultError(err error) {
defer ctx.c.arena.mark()()
ptr := ctx.c.arena.string(msg)
ctx.c.call("sqlite3_result_error",
uint64(ctx.handle), uint64(ptr), uint64(len(msg)))
stk_t(ctx.handle), stk_t(ptr), stk_t(len(msg)))
}
if code != _OK {
ctx.c.call("sqlite3_result_error_code",
uint64(ctx.handle), uint64(code))
stk_t(ctx.handle), stk_t(code))
}
}
@@ -223,6 +223,6 @@ func (ctx Context) ResultError(err error) {
//
// https://sqlite.org/c3ref/vtab_nochange.html
func (ctx Context) VTabNoChange() bool {
r := ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle))
return r != 0
b := int32(ctx.c.call("sqlite3_vtab_nochange", stk_t(ctx.handle)))
return b != 0
}

View File

@@ -201,7 +201,7 @@ func (n *connector) Driver() driver.Driver {
return &SQLite{}
}
func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
c := &conn{
txLock: n.txLock,
tmRead: n.tmRead,
@@ -213,7 +213,7 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
return nil, err
}
defer func() {
if res == nil {
if ret == nil {
c.Close()
}
}()
@@ -466,8 +466,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
old := s.Stmt.Conn().SetInterrupt(ctx)
defer s.Stmt.Conn().SetInterrupt(old)
err = s.Stmt.Exec()
s.Stmt.ClearBindings()
err = errors.Join(
s.Stmt.Exec(),
s.Stmt.ClearBindings())
if err != nil {
return nil, err
}
@@ -604,8 +605,9 @@ var (
)
func (r *rows) Close() error {
r.Stmt.ClearBindings()
return r.Stmt.Reset()
return errors.Join(
r.Stmt.Reset(),
r.Stmt.ClearBindings())
}
func (r *rows) Columns() []string {
@@ -718,19 +720,19 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
switch scan {
case _INT:
return reflect.TypeOf(int64(0))
return reflect.TypeFor[int64]()
case _REAL:
return reflect.TypeOf(float64(0))
return reflect.TypeFor[float64]()
case _TEXT:
return reflect.TypeOf("")
return reflect.TypeFor[string]()
case _BLOB:
return reflect.TypeOf([]byte{})
return reflect.TypeFor[[]byte]()
case _BOOL:
return reflect.TypeOf(false)
return reflect.TypeFor[bool]()
case _TIME:
return reflect.TypeOf(time.Time{})
return reflect.TypeFor[time.Time]()
default:
return reflect.TypeOf((*any)(nil)).Elem()
return reflect.TypeFor[any]()
}
}

View File

@@ -369,13 +369,13 @@ func Test_time(t *testing.T) {
func Test_ColumnType_ScanType(t *testing.T) {
var (
INT = reflect.TypeOf(int64(0))
REAL = reflect.TypeOf(float64(0))
TEXT = reflect.TypeOf("")
BLOB = reflect.TypeOf([]byte{})
BOOL = reflect.TypeOf(false)
TIME = reflect.TypeOf(time.Time{})
ANY = reflect.TypeOf((*any)(nil)).Elem()
INT = reflect.TypeFor[int64]()
REAL = reflect.TypeFor[float64]()
TEXT = reflect.TypeFor[string]()
BLOB = reflect.TypeFor[[]byte]()
BOOL = reflect.TypeFor[bool]()
TIME = reflect.TypeFor[time.Time]()
ANY = reflect.TypeFor[any]()
)
t.Parallel()

View File

@@ -3,7 +3,7 @@ package driver
import (
"context"
"database/sql/driver"
"reflect"
"slices"
"testing"
_ "github.com/ncruces/go-sqlite3/embed"
@@ -16,7 +16,7 @@ func Test_namedValues(t *testing.T) {
{Ordinal: 2, Value: false},
}
got := namedValues([]driver.Value{true, false})
if !reflect.DeepEqual(got, want) {
if !slices.Equal(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

View File

@@ -1,6 +1,6 @@
# Embeddable Wasm build of SQLite
This folder includes an embeddable Wasm build of SQLite 3.47.2 for use with
This folder includes an embeddable Wasm build of SQLite 3.49.1 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:

View File

@@ -1,13 +1,19 @@
# Embeddable Wasm build of SQLite
This folder includes an embeddable Wasm build of SQLite, including the experimental
This folder includes an alternative embeddable Wasm build of SQLite,
which includes the experimental
[`BEGIN CONCURRENT`](https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md) and
[Wal2](https://sqlite.org/cgi/src/doc/wal2/doc/wal2.md) patches.
It also enables the optional
[`UPDATE … ORDER BY … LIMIT`](https://sqlite.org/lang_update.html#optional_limit_and_order_by_clauses) and
[`DELETE … ORDER BY … LIMIT`](https://sqlite.org/lang_delete.html#optional_limit_and_order_by_clauses) clauses,
and the [`WITHIN GROUP ORDER BY`](https://sqlite.org/compile.html#enable_ordered_set_aggregates) aggregate syntax.
> [!IMPORTANT]
> This package is experimental.
> It is built from the `bedrock` branch of SQLite,
> since that is _currently_ the most stable, maintained branch to include both features.
> since that is _currently_ the most stable, maintained branch to include these features.
> [!CAUTION]
> The Wal2 journaling mode creates databases that other versions of SQLite cannot access.

Binary file not shown.

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/ext/stats"
"github.com/ncruces/go-sqlite3/vfs"
)
@@ -15,7 +16,7 @@ func Test_bcw2(t *testing.T) {
tmp := filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))
db, err := driver.Open("file:" + tmp + "?_pragma=journal_mode(wal2)&_txlock=concurrent")
db, err := driver.Open("file:"+tmp+"?_pragma=journal_mode(wal2)&_txlock=concurrent", stats.Register)
if err != nil {
t.Fatal(err)
}
@@ -37,6 +38,11 @@ func Test_bcw2(t *testing.T) {
t.Fatal(err)
}
_, err = tx.Exec(`SELECT median() WITHIN GROUP (ORDER BY col) FROM test`)
if err != nil {
t.Fatal(err)
}
err = tx.Commit()
if err != nil {
t.Fatal(err)
@@ -47,7 +53,7 @@ func Test_bcw2(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if version != "3.48.0" {
if version != "3.50.0" {
t.Error(version)
}
}

View File

@@ -13,14 +13,16 @@ mkdir -p build/ext/
cp "$ROOT"/sqlite3/*.[ch] build/
cp "$ROOT"/sqlite3/*.patch build/
# https://sqlite.org/src/info/ec5d7025cba9f4ac
curl -# https://sqlite.org/src/tarball/sqlite.tar.gz?r=ec5d7025 | tar xz
# https://sqlite.org/src/info/c09656c62155a6e8
curl -# https://sqlite.org/src/tarball/sqlite.tar.gz?r=c09656c6 | tar xz
cd sqlite
cat ../repro.patch | patch -p0 --no-backup-if-mismatch
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
MSYS_NO_PATHCONV=1 nmake /f makefile.msc sqlite3.c OPTS=-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT
MSYS_NO_PATHCONV=1 nmake /f makefile.msc sqlite3.c "OPTS=-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT -DSQLITE_ENABLE_ORDERED_SET_AGGREGATES"
else
sh configure --enable-update-limit && make sqlite3.c
sh configure --enable-update-limit
OPTS=-DSQLITE_ENABLE_ORDERED_SET_AGGREGATES make sqlite3.c
fi
cd ~-
@@ -37,7 +39,7 @@ mv sqlite/ext/misc/spellfix.c build/ext/
mv sqlite/ext/misc/uint.c build/ext/
cd build
cat *.patch | patch --no-backup-if-mismatch
cat *.patch | patch -p0 --no-backup-if-mismatch
cd ~-
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -g0 -O2 \

View File

@@ -1,13 +1,14 @@
module github.com/ncruces/go-sqlite3/embed/bcw2
go 1.21
go 1.23.0
toolchain go1.23.0
toolchain go1.24.0
require github.com/ncruces/go-sqlite3 v0.21.3
require github.com/ncruces/go-sqlite3 v0.24.0
require (
github.com/ncruces/julianday v1.0.0 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
golang.org/x/sys v0.29.0 // indirect
github.com/ncruces/sort v0.1.5 // indirect
github.com/tetratelabs/wazero v1.9.0 // indirect
golang.org/x/sys v0.30.0 // indirect
)

View File

@@ -1,10 +1,12 @@
github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1CdMQ=
github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
github.com/ncruces/go-sqlite3 v0.24.0 h1:Z4jfmzu2NCd4SmyFwLT2OmF3EnTZbqwATvdiuNHNhLA=
github.com/ncruces/go-sqlite3 v0.24.0/go.mod h1:/Vs8ACZHjJ1SA6E9RZUn3EyB1OP3nDQ4z/ar+0fplTQ=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
github.com/ncruces/sort v0.1.5 h1:fiFWXXAqKI8QckPf/6hu/bGFwcEPrirIOFaJqWujs4k=
github.com/ncruces/sort v0.1.5/go.mod h1:obJToO4rYr6VWP0Uw5FYymgYGt3Br4RXcs/JdKaXAPk=
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=

View File

@@ -11,13 +11,14 @@ package bcw2
import (
_ "embed"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
//go:embed bcw2.wasm
var binary []byte
var binary string
func init() {
sqlite3.Binary = binary
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
}

23
embed/bcw2/repro.patch Normal file
View File

@@ -0,0 +1,23 @@
# https://sqlite.org/src/vpatch?from=67809715977a5bad&to=3f57584710d61174
--- tool/mkpragmatab.tcl
+++ tool/mkpragmatab.tcl
@@ -526,14 +526,17 @@
puts $fd [format {#define PragFlg_%-10s 0x%02x /* %s */} \
$f $fv $flagMeaning($f)]
set fv [expr {$fv*2}]
}
-# Sort the column lists so that longer column lists occur first
+# Sort the column lists so that longer column lists occur first.
+# In the event of a tie, sort column lists lexicographically.
#
proc colscmp {a b} {
- return [expr {[llength $b] - [llength $a]}]
+ set rc [expr {[llength $b] - [llength $a]}]
+ if {$rc} {return $rc}
+ return [string compare $a $b]
}
set cols_list [lsort -command colscmp $cols_list]
# Generate the array of column names used by pragmas that act like
# queries.

View File

@@ -77,8 +77,10 @@ sqlite3_get_autocommit
sqlite3_get_auxdata
sqlite3_hard_heap_limit64
sqlite3_interrupt
sqlite3_invoke_busy_handler_go
sqlite3_last_insert_rowid
sqlite3_limit
sqlite3_log_go
sqlite3_malloc64
sqlite3_open_v2
sqlite3_overload_function

View File

@@ -8,13 +8,14 @@ package embed
import (
_ "embed"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
//go:embed sqlite3.wasm
var binary []byte
var binary string
func init() {
sqlite3.Binary = binary
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
}

View File

@@ -19,7 +19,7 @@ func Test_init(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if version != "3.47.2" {
if version != "3.49.1" {
t.Error(version)
}
}

Binary file not shown.

View File

@@ -15,7 +15,7 @@ type Error struct {
str string
msg string
sql string
code uint64
code res_t
}
// Code returns the primary error code for this error.
@@ -146,27 +146,27 @@ func (e ExtendedErrorCode) Code() ErrorCode {
return ErrorCode(e)
}
func errorCode(err error, def ErrorCode) (msg string, code uint32) {
func errorCode(err error, def ErrorCode) (msg string, code res_t) {
switch code := err.(type) {
case nil:
return "", _OK
case ErrorCode:
return "", uint32(code)
return "", res_t(code)
case xErrorCode:
return "", uint32(code)
return "", res_t(code)
case *Error:
return code.msg, uint32(code.code)
return code.msg, res_t(code.code)
}
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
code = uint32(xcode)
code = res_t(xcode)
case errors.As(err, &ecode):
code = uint32(ecode)
code = res_t(ecode)
default:
code = uint32(def)
code = res_t(def)
}
return err.Error(), code
}

View File

@@ -59,14 +59,14 @@ func TestError_Temporary(t *testing.T) {
tests := []struct {
name string
code uint64
code res_t
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), true},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
{"ERROR", res_t(ERROR), false},
{"BUSY", res_t(BUSY), true},
{"BUSY_RECOVERY", res_t(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -97,14 +97,14 @@ func TestError_Timeout(t *testing.T) {
tests := []struct {
name string
code uint64
code res_t
want bool
}{
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), false},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
{"ERROR", res_t(ERROR), false},
{"BUSY", res_t(BUSY), false},
{"BUSY_RECOVERY", res_t(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -136,8 +136,8 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call("sqlite3_errstr", uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME)
got := ErrorCode(i).Error()
if got != want {
@@ -158,8 +158,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call("sqlite3_errstr", uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME)
got := ExtendedErrorCode(i).Error()
if got != want {
@@ -172,7 +172,7 @@ func Test_errorCode(t *testing.T) {
tests := []struct {
arg error
wantMsg string
wantCode uint32
wantCode res_t
}{
{nil, "", _OK},
{ERROR, "", util.ERROR},
@@ -190,7 +190,7 @@ func Test_errorCode(t *testing.T) {
if gotMsg != tt.wantMsg {
t.Errorf("errorCode() gotMsg = %q, want %q", gotMsg, tt.wantMsg)
}
if gotCode != uint32(tt.wantCode) {
if gotCode != tt.wantCode {
t.Errorf("errorCode() gotCode = %d, want %d", gotCode, tt.wantCode)
}
})

View File

@@ -25,6 +25,8 @@ you can load into your database connections.
creates [pivot tables](https://github.com/jakethaw/pivot_vtab).
- [`github.com/ncruces/go-sqlite3/ext/regexp`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/regexp)
provides regular expression functions.
- [`github.com/ncruces/go-sqlite3/ext/serdes`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/serdes)
(de)serializes databases.
- [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement)
creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab).
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
@@ -34,4 +36,13 @@ you can load into your database connections.
- [`github.com/ncruces/go-sqlite3/ext/uuid`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/uuid)
generates [UUIDs](https://en.wikipedia.org/wiki/Universally_unique_identifier).
- [`github.com/ncruces/go-sqlite3/ext/zorder`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/zorder)
maps multidimensional data to one dimension.
maps multidimensional data to one dimension.
### Pakages
These packages may also be useful to work with SQLite:
- [`github.com/ncruces/decimal`](https://pkg.go.dev/github.com/ncruces/decimal)
decimal arithmetic.
- [`github.com/ncruces/julianday`](https://pkg.go.dev/github.com/ncruces/julianday)
Julian day math.

View File

@@ -4,7 +4,7 @@ import (
"io"
"log"
"os"
"reflect"
"slices"
"strings"
"testing"
@@ -278,7 +278,7 @@ func Test_openblob(t *testing.T) {
}
want := []string{"\xca\xfe", "\xba\xbe"}
if !reflect.DeepEqual(got, want) {
if !slices.Equal(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

View File

@@ -232,7 +232,7 @@ func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) {
}
defer f.Close()
for n := 0; n < b.hashes; n++ {
for n := range b.hashes {
hash := calcHash(n, blob)
hash %= uint64(b.bytes * 8)
bitpos := byte(hash % 8)
@@ -268,13 +268,13 @@ func (b *bloom) Open() (sqlite3.VTabCursor, error) {
type cursor struct {
*bloom
arg *sqlite3.Value
arg sqlite3.Value
eof bool
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.eof = false
c.arg = &arg[0]
c.arg = arg[0]
blob := arg[0].RawBlob()
f, err := c.db.OpenBlob(c.schema, c.storage, "data", 1, false)
@@ -312,7 +312,7 @@ func (c *cursor) Column(ctx sqlite3.Context, n int) error {
case 0:
ctx.ResultBool(true)
case 1:
ctx.ResultValue(*c.arg)
ctx.ResultValue(c.arg)
}
return nil
}

View File

@@ -210,12 +210,14 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.nodes = []node{{root, 0}}
set := util.Set[int64]{}
set.Add(root)
for i := 0; i < len(c.nodes); i++ {
for i := range c.nodes {
curr := c.nodes[i]
if curr.depth >= maxDepth {
continue
}
stmt.BindInt64(1, curr.id)
if err := stmt.BindInt64(1, curr.id); err != nil {
return err
}
for stmt.Step() {
if stmt.ColumnType(0) == sqlite3.INTEGER {
next := stmt.ColumnInt64(0)
@@ -225,7 +227,9 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
}
}
}
stmt.Reset()
if err := stmt.Reset(); err != nil {
return err
}
}
return nil
}

View File

@@ -30,7 +30,7 @@ func Register(db *sqlite3.Conn) error {
// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
data string
@@ -214,7 +214,10 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return err
}
if c.table.header {
c.Next() // skip header
err = c.Next() // skip header
if err != nil {
return err
}
}
c.rowID = 0
return c.Next()

View File

@@ -1,70 +0,0 @@
//go:build !go1.23
package fileio
import (
"fmt"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Adapted from: https://research.swtch.com/coro
const errCoroCanceled = util.ErrorString("coroutine canceled")
func coroNew[In, Out any](f func(In, func(Out) In) Out) (resume func(In) (Out, bool), cancel func()) {
type msg[T any] struct {
panic any
val T
}
cin := make(chan msg[In])
cout := make(chan msg[Out])
running := true
resume = func(in In) (out Out, ok bool) {
if !running {
return
}
cin <- msg[In]{val: in}
m := <-cout
if m.panic != nil {
panic(m.panic)
}
return m.val, running
}
cancel = func() {
if !running {
return
}
e := fmt.Errorf("%w", errCoroCanceled)
cin <- msg[In]{panic: e}
m := <-cout
if m.panic != nil && m.panic != e {
panic(m.panic)
}
}
yield := func(out Out) In {
cout <- msg[Out]{val: out}
m := <-cin
if m.panic != nil {
panic(m.panic)
}
return m.val
}
go func() {
defer func() {
if running {
running = false
cout <- msg[Out]{panic: recover()}
}
}()
var out Out
m := <-cin
if m.panic == nil {
out = f(m.val, yield)
}
running = false
cout <- msg[Out]{val: out}
}()
return resume, cancel
}

View File

@@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
}
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) {
func readfile(fsys fs.FS) sqlite3.ScalarFunction {
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
var err error
var data []byte

View File

@@ -2,6 +2,7 @@ package fileio
import (
"io/fs"
"iter"
"os"
"path"
"path/filepath"
@@ -62,12 +63,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
type cursor struct {
fsdir
base string
resume resume
cancel func()
curr entry
eof bool
rowID int64
base string
next func() (entry, bool)
stop func()
curr entry
eof bool
rowID int64
}
type entry struct {
@@ -77,8 +78,8 @@ type entry struct {
}
func (c *cursor) Close() error {
if c.cancel != nil {
c.cancel()
if c.stop != nil {
c.stop()
}
return nil
}
@@ -101,14 +102,26 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.base = base
}
c.resume, c.cancel = pull(c, root)
c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
walkDir := func(path string, d fs.DirEntry, err error) error {
if yield(entry{d, err, path}) {
return nil
}
return fs.SkipAll
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
})
c.eof = false
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
curr, ok := next(c)
curr, ok := c.next()
c.curr = curr
c.eof = !ok
c.rowID++

View File

@@ -1,29 +0,0 @@
//go:build !go1.23
package fileio
import (
"io/fs"
"path/filepath"
)
type resume = func(struct{}) (entry, bool)
func next(c *cursor) (entry, bool) {
return c.resume(struct{}{})
}
func pull(c *cursor, root string) (resume, func()) {
return coroNew(func(_ struct{}, yield func(entry) struct{}) entry {
walkDir := func(path string, d fs.DirEntry, err error) error {
yield(entry{d, err, path})
return nil
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
return entry{}
})
}

View File

@@ -1,31 +0,0 @@
//go:build go1.23
package fileio
import (
"io/fs"
"iter"
"path/filepath"
)
type resume = func() (entry, bool)
func next(c *cursor) (entry, bool) {
return c.resume()
}
func pull(c *cursor, root string) (resume, func()) {
return iter.Pull(func(yield func(entry) bool) {
walkDir := func(path string, d fs.DirEntry, err error) error {
if yield(entry{d, err, path}) {
return nil
}
return fs.SkipAll
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
})
}

View File

@@ -25,14 +25,14 @@ type table struct {
cols []*sqlite3.Value
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) {
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err error) {
if len(arg) != 3 {
return nil, fmt.Errorf("pivot: wrong number of arguments")
}
t := &table{db: db}
defer func() {
if res == nil {
if ret == nil {
t.Close()
}
}()
@@ -99,10 +99,11 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err e
}
func (t *table) Close() error {
var errs []error
for _, c := range t.cols {
c.Close()
errs = append(errs, c.Close())
}
return nil
return errors.Join(errs...)
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
@@ -206,7 +207,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
func (c *cursor) Next() error {
if c.scan.Step() {
count := c.scan.ColumnCount()
for i := 0; i < count; i++ {
for i := range count {
err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
if err != nil {
return err

View File

@@ -16,7 +16,9 @@ package regexp
import (
"errors"
"regexp"
"regexp/syntax"
"strings"
"unicode/utf8"
"github.com/ncruces/go-sqlite3"
)
@@ -50,34 +52,83 @@ func Register(db *sqlite3.Conn) error {
// SELECT column WHERE column GLOB :glob_prefix AND column REGEXP :regexp
//
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
func GlobPrefix(re *regexp.Regexp) string {
prefix, complete := re.LiteralPrefix()
i := strings.IndexAny(prefix, "*?[")
if i < 0 {
if complete {
return prefix
}
i = len(prefix)
func GlobPrefix(expr string) string {
re, err := syntax.Parse(expr, syntax.Perl)
if err != nil {
return "" // no match possible
}
return prefix[:i] + "*"
prog, err := syntax.Compile(re.Simplify())
if err != nil {
return "" // notest
}
i := &prog.Inst[prog.Start]
var empty syntax.EmptyOp
loop1:
for {
switch i.Op {
case syntax.InstFail:
return "" // notest
case syntax.InstCapture, syntax.InstNop:
// skip
case syntax.InstEmptyWidth:
empty |= syntax.EmptyOp(i.Arg)
default:
break loop1
}
i = &prog.Inst[i.Out]
}
if empty&syntax.EmptyBeginText == 0 {
return "*" // not anchored
}
var glob strings.Builder
loop2:
for {
switch i.Op {
case syntax.InstFail:
return "" // notest
case syntax.InstCapture, syntax.InstEmptyWidth, syntax.InstNop:
// skip
case syntax.InstRune, syntax.InstRune1:
if len(i.Rune) != 1 || syntax.Flags(i.Arg)&syntax.FoldCase != 0 {
break loop2
}
switch r := i.Rune[0]; r {
case '*', '?', '[', utf8.RuneError:
break loop2
default:
glob.WriteRune(r)
}
default:
break loop2
}
i = &prog.Inst[i.Out]
}
glob.WriteByte('*')
return glob.String()
}
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
func load(ctx sqlite3.Context, arg []sqlite3.Value, i int) (*regexp.Regexp, error) {
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(expr)
if err != nil {
return nil, err
re, ok = arg[i].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[i].Text())
if err != nil {
return nil, err
}
re = r
}
re = r
ctx.SetAuxData(0, r)
ctx.SetAuxData(i, re)
}
return re, nil
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
_ = arg[1] // bounds check
re, err := load(ctx, 0, arg[0].Text())
re, err := load(ctx, arg, 0)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -87,18 +138,17 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
ctx.ResultBool(re.Match(text))
}
func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -113,7 +163,7 @@ func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -138,7 +188,7 @@ func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -166,16 +216,14 @@ func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
_ = arg[2] // bounds check
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
repl := arg[2].RawText()
text := arg[0].RawText()
var pos, n int
if len(arg) > 3 {
pos = arg[3].Int()

View File

@@ -3,8 +3,10 @@ package regexp
import (
"database/sql"
"regexp"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -103,24 +105,81 @@ func TestRegister_errors(t *testing.T) {
}
}
func TestRegister_pointer(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var got int
err = db.QueryRow(`SELECT regexp_count('ABCABCAXYaxy', ?, 1)`,
sqlite3.Pointer(regexp.MustCompile(`(?i)A.`))).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != 4 {
t.Errorf("got %d, want %d", got, 4)
}
}
func TestGlobPrefix(t *testing.T) {
tests := []struct {
re string
want string
}{
{``, ""},
{`a`, "a"},
{`a*`, "*"},
{`a+`, "a*"},
{`ab*`, "a*"},
{`ab+`, "ab*"},
{`a\?b`, "a*"},
{`[`, ""},
{``, "*"},
{`^`, "*"},
{`a`, "*"},
{`ab`, "*"},
{`^a`, "a*"},
{`^a*`, "*"},
{`^a+`, "a*"},
{`^ab*`, "a*"},
{`^ab+`, "ab*"},
{`^a\?b`, "a*"},
{`^[a-z]`, "*"},
}
for _, tt := range tests {
t.Run(tt.re, func(t *testing.T) {
if got := GlobPrefix(regexp.MustCompile(tt.re)); got != tt.want {
t.Errorf("GlobPrefix() = %v, want %v", got, tt.want)
if got := GlobPrefix(tt.re); got != tt.want {
t.Errorf("GlobPrefix(%v) = %v, want %v", tt.re, got, tt.want)
}
})
}
}
func FuzzGlobPrefix(f *testing.F) {
f.Add(``, ``)
f.Add(`[`, ``)
f.Add(`^`, ``)
f.Add(`a`, `a`)
f.Add(`ab`, `b`)
f.Add(`^a`, `a`)
f.Add(`^a*`, `ab`)
f.Add(`^a+`, `ab`)
f.Add(`^ab*`, `ab`)
f.Add(`^ab+`, `ab`)
f.Add(`^a\?b`, `ab`)
f.Add(`^[a-z]`, `ab`)
f.Fuzz(func(t *testing.T, lit, str string) {
re, err := regexp.Compile(lit)
if err != nil {
t.SkipNow()
}
if re.MatchString(str) {
prefix, ok := strings.CutSuffix(GlobPrefix(lit), "*")
if !ok {
t.Fatalf("missing * after %q for %q with %q", prefix, lit, str)
}
if !strings.HasPrefix(str, prefix) {
t.Fatalf("missing prefix %q for %q with %q", prefix, lit, str)
}
}
})
}

140
ext/serdes/serdes.go Normal file
View File

@@ -0,0 +1,140 @@
// Package serdes provides functions to (de)serialize databases.
package serdes
import (
"io"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
const vfsName = "github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS"
func init() {
vfs.Register(vfsName, sliceVFS{})
}
var fileToOpen = make(chan *sliceFile, 1)
// Serialize backs up a database into a byte slice.
//
// https://sqlite.org/c3ref/serialize.html
func Serialize(db *sqlite3.Conn, schema string) ([]byte, error) {
var file sliceFile
fileToOpen <- &file
err := db.Backup(schema, "file:serdes.db?vfs="+vfsName)
return file.data, err
}
// Deserialize restores a database from a byte slice,
// DESTROYING any contents previously stored in schema.
//
// To non-destructively open a database from a byte slice,
// consider alternatives like the ["reader"] or ["memdb"] VFSes.
//
// This differs from the similarly named SQLite API
// in that it DOES NOT disconnect from schema
// to reopen as an in-memory database.
//
// https://sqlite.org/c3ref/deserialize.html
//
// ["memdb"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb
// ["reader"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs
func Deserialize(db *sqlite3.Conn, schema string, data []byte) error {
fileToOpen <- &sliceFile{data}
return db.Restore(schema, "file:serdes.db?vfs="+vfsName)
}
type sliceVFS struct{}
func (sliceVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 || name != "serdes.db" {
return nil, flags, sqlite3.CANTOPEN
}
select {
case file := <-fileToOpen:
return file, flags | vfs.OPEN_MEMORY, nil
default:
return nil, flags, sqlite3.MISUSE
}
}
func (sliceVFS) Delete(name string, dirSync bool) error {
// notest // OPEN_MEMORY
return sqlite3.IOERR_DELETE
}
func (sliceVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return name == "serdes.db", nil
}
func (sliceVFS) FullPathname(name string) (string, error) {
return name, nil
}
type sliceFile struct{ data []byte }
func (f *sliceFile) ReadAt(b []byte, off int64) (n int, err error) {
if d := f.data; off < int64(len(d)) {
n = copy(b, d[off:])
}
if n == 0 {
err = io.EOF
}
return
}
func (f *sliceFile) WriteAt(b []byte, off int64) (n int, err error) {
if d := f.data; off > int64(len(d)) {
f.data = append(d, make([]byte, off-int64(len(d)))...)
}
d := append(f.data[:off], b...)
if len(d) > len(f.data) {
f.data = d
}
return len(b), nil
}
func (f *sliceFile) Size() (int64, error) {
return int64(len(f.data)), nil
}
func (f *sliceFile) Truncate(size int64) error {
if d := f.data; size < int64(len(d)) {
f.data = d[:size]
}
return nil
}
func (f *sliceFile) SizeHint(size int64) error {
if d := f.data; size > int64(len(d)) {
f.data = append(d, make([]byte, size-int64(len(d)))...)
}
return nil
}
func (*sliceFile) Close() error { return nil }
func (*sliceFile) Sync(flag vfs.SyncFlag) error { return nil }
func (*sliceFile) Lock(lock vfs.LockLevel) error { return nil }
func (*sliceFile) Unlock(lock vfs.LockLevel) error { return nil }
func (*sliceFile) CheckReservedLock() (bool, error) {
// notest // OPEN_MEMORY
return false, nil
}
func (*sliceFile) SectorSize() int {
// notest // IOCAP_POWERSAFE_OVERWRITE
return 0
}
func (*sliceFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_ATOMIC |
vfs.IOCAP_SAFE_APPEND |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_POWERSAFE_OVERWRITE |
vfs.IOCAP_SUBPAGE_READ
}

87
ext/serdes/serdes_test.go Normal file
View File

@@ -0,0 +1,87 @@
package serdes_test
import (
"errors"
"io"
"net/http"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/serdes"
)
func TestDeserialize(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
input, err := httpGet()
if err != nil {
t.Fatal(err)
}
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = serdes.Deserialize(db, "temp", input)
if err != nil {
t.Fatal(err)
}
output, err := serdes.Serialize(db, "temp")
if err != nil {
t.Fatal(err)
}
if len(input) != len(output) {
t.Fatal("lengths are different")
}
for i := range input {
// These may be different.
switch {
case 24 <= i && i < 28:
// File change counter.
continue
case 40 <= i && i < 44:
// Schema cookie.
continue
case 92 <= i && i < 100:
// SQLite version that wrote the file.
continue
}
if input[i] != output[i] {
t.Errorf("difference at %d: %d %d", i, input[i], output[i])
}
}
}
func httpGet() ([]byte, error) {
res, err := http.Get("https://raw.githubusercontent.com/jpwhite3/northwind-SQLite3/refs/heads/main/dist/northwind.db")
if err != nil {
return nil, err
}
defer res.Body.Close()
return io.ReadAll(res.Body)
}
func TestOpen_errors(t *testing.T) {
_, err := sqlite3.Open("file:test.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
_, err = sqlite3.Open("file:serdes.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.MISUSE) {
t.Errorf("got %v, want sqlite3.MISUSE", err)
}
}

View File

@@ -8,6 +8,7 @@ package statement
import (
"encoding/json"
"errors"
"strconv"
"strings"
"unsafe"
@@ -43,7 +44,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
var str strings.Builder
str.WriteString("CREATE TABLE x(")
outputs := stmt.ColumnCount()
for i := 0; i < outputs; i++ {
for i := range outputs {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
@@ -150,17 +151,18 @@ type cursor struct {
func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
c.stmt.ClearBindings()
return c.stmt.Reset()
return errors.Join(
c.stmt.Reset(),
c.stmt.ClearBindings())
}
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.arg = arg
c.rowID = 0
c.stmt.ClearBindings()
if err := c.stmt.Reset(); err != nil {
err := errors.Join(
c.stmt.Reset(),
c.stmt.ClearBindings())
if err != nil {
return err
}
@@ -183,6 +185,8 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return err
}
}
c.arg = append(c.arg[:0], arg...)
c.rowID = 0
return c.Next()
}

View File

@@ -7,7 +7,7 @@ const (
some
)
func newBoolean(kind int) func() sqlite3.AggregateFunction {
func newBoolean(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
}

19
ext/stats/kahan.go Normal file
View File

@@ -0,0 +1,19 @@
package stats
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

112
ext/stats/mode.go Normal file
View File

@@ -0,0 +1,112 @@
package stats
import (
"unsafe"
"github.com/ncruces/go-sqlite3"
)
func newMode() sqlite3.AggregateFunction {
return &mode{}
}
type mode struct {
ints counter[int64]
reals counter[float64]
texts counter[string]
blobs counter[string]
}
func (m mode) Value(ctx sqlite3.Context) {
var (
max = 0
typ = sqlite3.NULL
i64 int64
f64 float64
str string
)
for k, v := range m.ints {
if v > max || v == max && k < i64 {
typ = sqlite3.INTEGER
max = v
i64 = k
}
}
f64 = float64(i64)
for k, v := range m.reals {
if v > max || v == max && k < f64 {
typ = sqlite3.FLOAT
max = v
f64 = k
}
}
for k, v := range m.texts {
if v > max || v == max && typ == sqlite3.TEXT && k < str {
typ = sqlite3.TEXT
max = v
str = k
}
}
for k, v := range m.blobs {
if v > max || v == max && typ == sqlite3.BLOB && k < str {
typ = sqlite3.BLOB
max = v
str = k
}
}
switch typ {
case sqlite3.INTEGER:
ctx.ResultInt64(i64)
case sqlite3.FLOAT:
ctx.ResultFloat(f64)
case sqlite3.TEXT:
ctx.ResultText(str)
case sqlite3.BLOB:
ctx.ResultBlob(unsafe.Slice(unsafe.StringData(str), len(str)))
}
}
func (b *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg[0].Type() {
case sqlite3.INTEGER:
b.ints.add(arg[0].Int64())
case sqlite3.FLOAT:
b.reals.add(arg[0].Float())
case sqlite3.TEXT:
b.texts.add(arg[0].Text())
case sqlite3.BLOB:
b.blobs.add(string(arg[0].RawBlob()))
}
}
func (b *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg[0].Type() {
case sqlite3.INTEGER:
b.ints.del(arg[0].Int64())
case sqlite3.FLOAT:
b.reals.del(arg[0].Float())
case sqlite3.TEXT:
b.texts.del(arg[0].Text())
case sqlite3.BLOB:
b.blobs.del(string(arg[0].RawBlob()))
}
}
type counter[T comparable] map[T]int
func (c *counter[T]) add(k T) {
if (*c) == nil {
(*c) = make(counter[T])
}
(*c)[k]++
}
func (c counter[T]) del(k T) {
switch n := c[k]; n {
default:
c[k] = n - 1
case 1:
delete(c, k)
case 0:
}
}

85
ext/stats/mode_test.go Normal file
View File

@@ -0,0 +1,85 @@
package stats_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func TestRegister_mode(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT mode(column1) FROM (VALUES (NULL), (1), (NULL), (2), (NULL), (3), (3))`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %v, want 3", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (1), (1), (2), (2), (3))`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 1 {
t.Errorf("got %v, want 1", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (0.5), (1), (2.5), (2), (2.5))`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 2.5 {
t.Errorf("got %v, want 2.5", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES ('red'), ('green'), ('blue'), ('red'))`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnText(0); got != "red" {
t.Errorf("got %q, want red", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (X'cafebabe'), ('green'), ('blue'), (X'cafebabe'))`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnText(0); got != "\xca\xfe\xba\xbe" {
t.Errorf("got %q, want cafebabe", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`
SELECT mode(column1) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM (VALUES (1), (1), (2.5), ('blue'), (X'cafebabe'), (1), (1))
`)
if err != nil {
t.Fatal(err)
}
for stmt.Step() {
}
stmt.Close()
}

101
ext/stats/moments.go Normal file
View File

@@ -0,0 +1,101 @@
package stats
import "math"
// FisherPearson skewness and kurtosis using
// Terriberry's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
type moments struct {
m1, m2, m3, m4 kahan
n int64
}
func (m moments) mean() float64 {
return m.m1.hi
}
func (m moments) var_pop() float64 {
return m.m2.hi / float64(m.n)
}
func (m moments) var_samp() float64 {
return m.m2.hi / float64(m.n-1) // Bessel's correction
}
func (m moments) stddev_pop() float64 {
return math.Sqrt(m.var_pop())
}
func (m moments) stddev_samp() float64 {
return math.Sqrt(m.var_samp())
}
func (m moments) skewness_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2 * m2; div != 0 {
return m.m3.hi * math.Sqrt(float64(m.n)/div)
}
return math.NaN()
}
func (m moments) skewness_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/skewness.html#f1132178
return m.skewness_pop() * math.Sqrt(float64(n*(n-1))) / float64(n-2)
}
func (m moments) kurtosis_pop() float64 {
return m.raw_kurtosis_pop() - 3
}
func (m moments) raw_kurtosis_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2; div != 0 {
return m.m4.hi * float64(m.n) / div
}
return math.NaN()
}
func (m moments) kurtosis_samp() float64 {
n := m.n
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
return k * float64(n-1) / float64((n-2)*(n-3))
}
func (m moments) raw_kurtosis_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/kurtosis.html#f4975293
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
return math.FMA(k, float64(n-1)/float64((n-2)*(n-3)), 3)
}
func (m *moments) enqueue(x float64) {
n := m.n + 1
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n-1)
m.m4.add(t1*d2*float64(n*n-3*n+3) + 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.add(t1*dn*float64(n-2) - 3*dn*m.m2.hi)
m.m2.add(t1)
m.m1.add(dn)
}
func (m *moments) dequeue(x float64) {
n := m.n - 1
if n <= 0 {
*m = moments{}
return
}
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n+1)
m.m4.sub(t1*d2*float64(n*n+3*n+3) - 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.sub(t1*dn*float64(n+2) - 3*dn*m.m2.hi)
m.m2.sub(t1)
m.m1.sub(dn)
}

87
ext/stats/moments_test.go Normal file
View File

@@ -0,0 +1,87 @@
package stats
import (
"math"
"testing"
)
func Test_moments(t *testing.T) {
t.Parallel()
var s1 moments
s1.enqueue(1)
s1.dequeue(1)
if !math.IsNaN(s1.skewness_pop()) {
t.Errorf("want NaN")
}
if !math.IsNaN(s1.raw_kurtosis_pop()) {
t.Errorf("want NaN")
}
s1.enqueue(+0.5377)
s1.enqueue(+1.8339)
s1.enqueue(-2.2588)
s1.enqueue(+0.8622)
s1.enqueue(+0.3188)
s1.enqueue(-1.3077)
s1.enqueue(-0.4336)
s1.enqueue(+0.3426)
s1.enqueue(+3.5784)
s1.enqueue(+2.7694)
if got := s1.skewness_pop(); float32(got) != 0.106098293 {
t.Errorf("got %v, want 0.1061", got)
}
if got := s1.skewness_samp(); float32(got) != 0.1258171 {
t.Errorf("got %v, want 0.1258", got)
}
if got := s1.raw_kurtosis_pop(); float32(got) != 2.3121266 {
t.Errorf("got %v, want 2.3121", got)
}
if got := s1.raw_kurtosis_samp(); float32(got) != 2.7482237 {
t.Errorf("got %v, want 2.7483", got)
}
var s2 welford
s2.enqueue(+0.5377)
s2.enqueue(+1.8339)
s2.enqueue(-2.2588)
s2.enqueue(+0.8622)
s2.enqueue(+0.3188)
s2.enqueue(-1.3077)
s2.enqueue(-0.4336)
s2.enqueue(+0.3426)
s2.enqueue(+3.5784)
s2.enqueue(+2.7694)
if got, want := s1.mean(), s2.mean(); got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := s1.stddev_pop(), s2.stddev_pop(); got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := s1.stddev_samp(), s2.stddev_samp(); got != want {
t.Errorf("got %v, want %v", got, want)
}
s1.enqueue(math.Pi)
s1.enqueue(math.Sqrt2)
s1.enqueue(math.E)
s1.dequeue(math.Pi)
s1.dequeue(math.E)
s1.dequeue(math.Sqrt2)
if got := s1.skewness_pop(); float32(got) != 0.106098293 {
t.Errorf("got %v, want 0.1061", got)
}
if got := s1.skewness_samp(); float32(got) != 0.1258171 {
t.Errorf("got %v, want 0.1258", got)
}
if got := s1.raw_kurtosis_pop(); float32(got) != 2.3121266 {
t.Errorf("got %v, want 2.3121", got)
}
if got := s1.raw_kurtosis_samp(); float32(got) != 2.7482237 {
t.Errorf("got %v, want 2.7483", got)
}
}

View File

@@ -11,6 +11,9 @@ import (
"github.com/ncruces/sort/quick"
)
// Compatible with:
// https://sqlite.org/src/file/ext/misc/percentile.c
const (
median = iota
percentile_100
@@ -18,7 +21,7 @@ const (
percentile_disc
)
func newPercentile(kind int) func() sqlite3.AggregateFunction {
func newPercentile(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
}

View File

@@ -1,13 +1,17 @@
// Package stats provides aggregate functions for statistics.
//
// Provided functions:
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - var_pop: population variance
// - var_samp: sample variance
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - skewness_pop: Pearson population skewness
// - skewness_samp: Pearson sample skewness
// - kurtosis_pop: Fisher population excess kurtosis
// - kurtosis_samp: Fisher sample excess kurtosis
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
// - corr: Pearson correlation coefficient
// - regr_r2: correlation coefficient squared
// - regr_avgx: average of the independent variable
// - regr_avgy: average of the dependent variable
@@ -17,10 +21,12 @@
// - regr_count: count non-null pairs of variables
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
// - regr_json: all regr stats in a JSON object
// - percentile_disc: discrete percentile
// - percentile_cont: continuous percentile
// - median: median value
// - regr_json: all regr stats as a JSON object
// - percentile_disc: discrete quantile
// - percentile_cont: continuous quantile
// - percentile: continuous percentile
// - median: middle value
// - mode: most frequent value
// - every: boolean and
// - some: boolean or
//
@@ -59,6 +65,10 @@ func Register(db *sqlite3.Conn) error {
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
db.CreateWindowFunction("skewness_pop", 1, flags, newMoments(skewness_pop)),
db.CreateWindowFunction("skewness_samp", 1, flags, newMoments(skewness_samp)),
db.CreateWindowFunction("kurtosis_pop", 1, flags, newMoments(kurtosis_pop)),
db.CreateWindowFunction("kurtosis_samp", 1, flags, newMoments(kurtosis_samp)),
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
@@ -77,7 +87,8 @@ func Register(db *sqlite3.Conn) error {
db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)),
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
db.CreateWindowFunction("some", 1, flags, newBoolean(some)),
db.CreateWindowFunction("mode", 1, order, newMode))
}
const (
@@ -85,6 +96,10 @@ const (
var_samp
stddev_pop
stddev_samp
skewness_pop
skewness_samp
kurtosis_pop
kurtosis_samp
corr
regr_r2
regr_sxx
@@ -98,7 +113,24 @@ const (
regr_json
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
func special(kind int, n int64) (null, zero bool) {
switch kind {
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
return n <= 0, n == 1
case regr_avgx, regr_avgy:
return n <= 0, false
case kurtosis_samp:
return n <= 3, false
case skewness_samp:
return n <= 2, false
case skewness_pop:
return n <= 1, n == 2
default:
return n <= 1, false
}
}
func newVariance(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
}
@@ -108,6 +140,14 @@ type variance struct {
}
func (fn *variance) Value(ctx sqlite3.Context) {
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case var_pop:
@@ -138,7 +178,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
}
func newCovariance(kind int) func() sqlite3.AggregateFunction {
func newCovariance(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
}
@@ -148,6 +188,18 @@ type covariance struct {
}
func (fn *covariance) Value(ctx sqlite3.Context) {
if fn.kind == regr_count {
ctx.ResultInt64(fn.regr_count())
return
}
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case var_pop:
@@ -172,11 +224,9 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
r = fn.regr_slope()
case regr_intercept:
r = fn.regr_intercept()
case regr_count:
ctx.ResultInt64(fn.regr_count())
return
case regr_json:
ctx.ResultText(fn.regr_json())
var buf [128]byte
ctx.ResultRawText(fn.regr_json(buf[:0]))
return
}
ctx.ResultFloat(r)
@@ -203,3 +253,51 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
fn.dequeue(fa, fb)
}
}
func newMoments(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
}
type momentfn struct {
kind int
moments
}
func (fn *momentfn) Value(ctx sqlite3.Context) {
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case skewness_pop:
r = fn.skewness_pop()
case skewness_samp:
r = fn.skewness_samp()
case kurtosis_pop:
r = fn.kurtosis_pop()
case kurtosis_samp:
r = fn.kurtosis_samp()
}
ctx.ResultFloat(r)
}
func (fn *momentfn) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.enqueue(f)
}
}
func (fn *momentfn) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.dequeue(f)
}
}

View File

@@ -29,16 +29,29 @@ func TestRegister_variance(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
stmt.Close()
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`
stmt, _, err = db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
stddev_samp(x), stddev_pop(x)
stddev_samp(x), stddev_pop(x),
skewness_samp(x), skewness_pop(x),
kurtosis_samp(x), kurtosis_pop(x)
FROM data`)
if err != nil {
t.Fatal(err)
@@ -62,10 +75,27 @@ func TestRegister_variance(t *testing.T) {
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
if got := stmt.ColumnFloat(6); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(7); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(8); float32(got) != -3.3 {
t.Errorf("got %v, want -3.3", got)
}
if got := stmt.ColumnFloat(9); got != -1.64 {
t.Errorf("got %v, want -1.64", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
stmt, _, err = db.Prepare(`
SELECT
var_samp(x) OVER (ROWS 1 PRECEDING),
var_pop(x) OVER (ROWS 1 PRECEDING),
skewness_pop(x) OVER (ROWS 1 PRECEDING)
FROM data`)
if err != nil {
t.Fatal(err)
}
@@ -96,12 +126,26 @@ func TestRegister_covariance(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 0", got)
}
if got := stmt.ColumnType(1); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
stmt.Close()
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
stmt, _, err = db.Prepare(`SELECT
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
@@ -157,7 +201,12 @@ func TestRegister_covariance(t *testing.T) {
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
stmt, _, err = db.Prepare(`
SELECT
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
FROM data`)
if err != nil {
t.Fatal(err)
}
@@ -171,6 +220,9 @@ func TestRegister_covariance(t *testing.T) {
t.Errorf("got %v, want %v", got, want[i])
}
}
if stmt.Err() != nil {
t.Fatal(stmt.Err())
}
stmt.Close()
}

View File

@@ -3,22 +3,20 @@ package stats
import (
"math"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Welford's algorithm with Kahan summation:
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
// See also:
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
type welford struct {
m1, m2 kahan
n int64
}
func (w welford) average() float64 {
func (w welford) mean() float64 {
return w.m1.hi
}
@@ -39,17 +37,23 @@ func (w welford) stddev_samp() float64 {
}
func (w *welford) enqueue(x float64) {
w.n++
n := w.n + 1
w.n = n
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(w.n))
w.m1.add(d1 / float64(n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}
func (w *welford) dequeue(x float64) {
w.n--
n := w.n - 1
if n <= 0 {
*w = welford{}
return
}
w.n = n
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(w.n))
w.m1.sub(d1 / float64(n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
@@ -112,38 +116,35 @@ func (w welford2) regr_r2() float64 {
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
}
func (w welford2) regr_json() string {
var json strings.Builder
var num [32]byte
json.Grow(128)
json.WriteString(`{"count":`)
json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10))
json.WriteString(`,"avgy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64))
json.WriteString(`,"avgx":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64))
json.WriteString(`,"syy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64))
json.WriteString(`,"sxx":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64))
json.WriteString(`,"sxy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64))
json.WriteString(`,"slope":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64))
json.WriteString(`,"intercept":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64))
json.WriteString(`,"r2":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64))
json.WriteByte('}')
return json.String()
func (w welford2) regr_json(dst []byte) []byte {
dst = append(dst, `{"count":`...)
dst = strconv.AppendInt(dst, w.regr_count(), 10)
dst = append(dst, `,"avgy":`...)
dst = util.AppendNumber(dst, w.regr_avgy())
dst = append(dst, `,"avgx":`...)
dst = util.AppendNumber(dst, w.regr_avgx())
dst = append(dst, `,"syy":`...)
dst = util.AppendNumber(dst, w.regr_syy())
dst = append(dst, `,"sxx":`...)
dst = util.AppendNumber(dst, w.regr_sxx())
dst = append(dst, `,"sxy":`...)
dst = util.AppendNumber(dst, w.regr_sxy())
dst = append(dst, `,"slope":`...)
dst = util.AppendNumber(dst, w.regr_slope())
dst = append(dst, `,"intercept":`...)
dst = util.AppendNumber(dst, w.regr_intercept())
dst = append(dst, `,"r2":`...)
dst = util.AppendNumber(dst, w.regr_r2())
return append(dst, '}')
}
func (w *welford2) enqueue(y, x float64) {
w.n++
n := w.n + 1
w.n = n
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.add(d1y / float64(w.n))
w.m1x.add(d1x / float64(w.n))
w.m1y.add(d1y / float64(n))
w.m1x.add(d1x / float64(n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.add(d1y * d2y)
@@ -152,30 +153,19 @@ func (w *welford2) enqueue(y, x float64) {
}
func (w *welford2) dequeue(y, x float64) {
w.n--
n := w.n - 1
if n <= 0 {
*w = welford2{}
return
}
w.n = n
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.sub(d1y / float64(w.n))
w.m1x.sub(d1x / float64(w.n))
w.m1y.sub(d1y / float64(n))
w.m1x.sub(d1x / float64(n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.sub(d1y * d2y)
w.m2x.sub(d1x * d2x)
w.cov.sub(d1y * d2x)
}
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

View File

@@ -9,12 +9,14 @@ func Test_welford(t *testing.T) {
t.Parallel()
var s1, s2 welford
s1.enqueue(1)
s1.dequeue(1)
s1.enqueue(4)
s1.enqueue(7)
s1.enqueue(13)
s1.enqueue(16)
if got := s1.average(); got != 10 {
if got := s1.mean(); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := s1.var_samp(); got != 30 {
@@ -43,6 +45,8 @@ func Test_covar(t *testing.T) {
t.Parallel()
var c1, c2 welford2
c1.enqueue(1, 1)
c1.dequeue(1, 1)
c1.enqueue(3, 70)
c1.enqueue(5, 80)

View File

@@ -5,16 +5,18 @@
// - LIKE and REGEXP operators,
// - collation sequences.
//
// It also provides, from PostgreSQL:
// - unaccent(),
// - initcap().
//
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regexp/syntax];
// - collation sequences use [collate].
//
// It also provides (approximately) from PostgreSQL:
// - casefold(),
// - initcap(),
// - normalize(),
// - unaccent().
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
@@ -48,13 +50,13 @@ var RegisterLike = true
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
var errs util.ErrorJoiner
var lkfn sqlite3.ScalarFunction
if RegisterLike {
errs.Join(
db.CreateFunction("like", 2, flags, like),
db.CreateFunction("like", 3, flags, like))
lkfn = like
}
errs.Join(
return errors.Join(
db.CreateFunction("like", 2, flags, lkfn),
db.CreateFunction("like", 3, flags, lkfn),
db.CreateFunction("upper", 1, flags, upper),
db.CreateFunction("upper", 2, flags, upper),
db.CreateFunction("lower", 1, flags, lower),
@@ -62,7 +64,10 @@ func Register(db *sqlite3.Conn) error {
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("initcap", 1, flags, initcap),
db.CreateFunction("initcap", 2, flags, initcap),
db.CreateFunction("casefold", 1, flags, casefold),
db.CreateFunction("unaccent", 1, flags, unaccent),
db.CreateFunction("normalize", 1, flags, normalize),
db.CreateFunction("normalize", 2, flags, normalize),
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
@@ -76,7 +81,6 @@ func Register(db *sqlite3.Conn) error {
return // notest
}
}))
return errors.Join(errs...)
}
// RegisterCollation registers a Unicode collation sequence for a database connection.
@@ -109,9 +113,8 @@ func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
c := cases.Upper(t)
ctx.SetAuxData(1, c)
cs = c
cs = cases.Upper(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -128,9 +131,8 @@ func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
c := cases.Lower(t)
ctx.SetAuxData(1, c)
cs = c
cs = cases.Lower(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -147,13 +149,16 @@ func initcap(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
c := cases.Title(t)
ctx.SetAuxData(1, c)
cs = c
cs = cases.Title(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
func casefold(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultRawText(cases.Fold().Bytes(arg[0].RawText()))
}
func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
unaccent := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
res, _, err := transform.Bytes(unaccent, arg[0].RawText())
@@ -164,16 +169,44 @@ func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
}
func normalize(ctx sqlite3.Context, arg ...sqlite3.Value) {
form := norm.NFC
if len(arg) > 1 {
switch strings.ToUpper(arg[1].Text()) {
case "NFC":
//
case "NFD":
form = norm.NFD
case "NFKC":
form = norm.NFKC
case "NFKD":
form = norm.NFKD
default:
ctx.ResultError(util.ErrorString("unicode: invalid form"))
return
}
}
res, _, err := transform.Bytes(form, arg[0].RawText())
if err != nil {
ctx.ResultError(err) // notest
} else {
ctx.ResultRawText(res)
}
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
re, ok = arg[0].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
re = r
}
re = r
ctx.SetAuxData(0, r)
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
}

View File

@@ -2,7 +2,7 @@ package unicode
import (
"errors"
"reflect"
"slices"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -49,6 +49,12 @@ func TestRegister(t *testing.T) {
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
{`initcap('Kad je hladno Marko nosi džemper')`, "Kad Je Hladno Marko Nosi Džemper"},
{`initcap('Kad je hladno Marko nosi džemper', 'hr-HR')`, "Kad Je Hladno Marko Nosi Džemper"},
{`normalize(X'61cc88')`, "ä"},
{`normalize(X'61cc88', 'NFC' )`, "ä"},
{`normalize(X'61cc88', 'NFKC')`, "ä"},
{`normalize('ä', 'NFD' )`, "\x61\xcc\x88"},
{`normalize('ä', 'NFKD')`, "\x61\xcc\x88"},
{`casefold('Maße')`, "masse"},
{`unaccent('Hôtel')`, "Hotel"},
{`'Hello' REGEXP 'ell'`, "1"},
{`'Hello' REGEXP 'el.'`, "1"},
@@ -115,7 +121,7 @@ func TestRegister_collation(t *testing.T) {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
if !slices.Equal(got, want) {
t.Error("not equal")
}
@@ -166,7 +172,7 @@ func TestRegisterCollationsNeeded(t *testing.T) {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
if !slices.Equal(got, want) {
t.Error("not equal")
}
@@ -208,6 +214,14 @@ func TestRegister_error(t *testing.T) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT normalize('', 'NF')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
if err == nil {
t.Error("want error")

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"errors"
"fmt"
"time"
"github.com/google/uuid"
@@ -35,7 +36,9 @@ func Register(db *sqlite3.Conn) error {
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid_str", 1, flags, toString),
db.CreateFunction("uuid_blob", 1, flags, toBlob))
db.CreateFunction("uuid_blob", 1, flags, toBlob),
db.CreateFunction("uuid_extract_version", 1, flags, version),
db.CreateFunction("uuid_extract_timestamp", 1, flags, timestamp))
}
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
@@ -167,3 +170,30 @@ func toString(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(u.String())
}
}
func version(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
return // notest
}
if u.Variant() == uuid.RFC4122 {
ctx.ResultInt64(int64(u.Version()))
}
}
func timestamp(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
return // notest
}
if u.Variant() == uuid.RFC4122 {
switch u.Version() {
case 1, 2, 6, 7:
ctx.ResultTime(
time.Unix(u.Time().UnixTime()),
sqlite3.TimeFormatDefault)
}
}
}

View File

@@ -2,6 +2,7 @@ package uuid
import (
"testing"
"time"
"github.com/google/uuid"
@@ -106,7 +107,26 @@ func Test_generate(t *testing.T) {
t.Error("want error")
}
hash := []struct {
var tstamp time.Time
var version uuid.Version
err = db.QueryRow(`
SELECT
column1,
uuid_extract_version(column1),
uuid_extract_timestamp(column1)
FROM (VALUES (uuid(7)))
`).Scan(&u, &version, &tstamp)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != version {
t.Errorf("got %d, want %d", got, version)
}
if got := time.Unix(u.Time().UnixTime()); !got.Equal(tstamp) {
t.Errorf("got %v, want %v", got, tstamp)
}
tests := []struct {
ver uuid.Version
ns any
data string
@@ -120,7 +140,7 @@ func Test_generate(t *testing.T) {
{3, "url", "https://www.php.net", uuid.MustParse("3f703955-aaba-3e70-a3cb-baff6aa3b28f")},
{5, "url", "https://www.php.net", uuid.MustParse("a8f6ae40-d8a7-58f0-be05-a22f94eca9ec")},
}
for _, tt := range hash {
for _, tt := range tests {
err = db.QueryRow(`SELECT uuid(?, ?, ?)`, tt.ver, tt.ns, tt.data).Scan(&u)
if err != nil {
t.Fatal(err)
@@ -142,14 +162,14 @@ func Test_convert(t *testing.T) {
defer db.Close()
var u uuid.UUID
lits := []string{
tests := []string{
"'6ba7b8119dad11d180b400c04fd430c8'",
"'6ba7b811-9dad-11d1-80b4-00c04fd430c8'",
"'{6ba7b811-9dad-11d1-80b4-00c04fd430c8}'",
"X'6ba7b8119dad11d180b400c04fd430c8'",
}
for _, tt := range lits {
for _, tt := range tests {
err = db.QueryRow(`SELECT uuid_str(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
@@ -159,7 +179,7 @@ func Test_convert(t *testing.T) {
}
}
for _, tt := range lits {
for _, tt := range tests {
err = db.QueryRow(`SELECT uuid_blob(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
@@ -178,4 +198,14 @@ func Test_convert(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
err = db.QueryRow(`SELECT uuid_extract_version(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
err = db.QueryRow(`SELECT uuid_extract_timestamp(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
}

243
func.go
View File

@@ -2,7 +2,10 @@ package sqlite3
import (
"context"
"io"
"iter"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
@@ -14,12 +17,12 @@ import (
//
// https://sqlite.org/c3ref/collation_needed.html
func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_collation_needed_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.collation = cb
@@ -33,8 +36,8 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c Conn) AnyCollationNeeded() error {
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
if err := c.error(r); err != nil {
rc := res_t(c.call("sqlite3_anycollseq_init", stk_t(c.handle), 0, 0))
if err := c.error(rc); err != nil {
return err
}
c.collation = nil
@@ -44,60 +47,103 @@ func (c Conn) AnyCollationNeeded() error {
// CreateCollation defines a new collating sequence.
//
// https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
var funcPtr uint32
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
r := c.call("sqlite3_create_collation_go",
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
return c.error(r)
rc := res_t(c.call("sqlite3_create_collation_go",
stk_t(c.handle), stk_t(namePtr), stk_t(funcPtr)))
return c.error(rc)
}
// Collating function is the type of a collation callback.
// Implementations must not retain a or b.
type CollatingFunction func(a, b []byte) int
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
var funcPtr uint32
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
r := c.call("sqlite3_create_function_go",
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
rc := res_t(c.call("sqlite3_create_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// ScalarFunction is the type of a scalar SQL function.
// Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value)
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
// CreateAggregateFunction defines a new aggregate SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
var funcPtr uint32
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
var a aggregateFunc
coro := func(yieldCoro func(struct{}) bool) {
seq := func(yieldSeq func([]Value) bool) {
for yieldSeq(a.arg) {
if !yieldCoro(struct{}{}) {
break
}
}
}
fn(&a.ctx, seq)
}
a.next, a.stop = iter.Pull(coro)
return &a
}))
}
call := "sqlite3_create_aggregate_function_go"
if _, ok := fn().(WindowFunction); ok {
call = "sqlite3_create_window_function_go"
}
r := c.call(call,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateSeqFunction is the type of an aggregate SQL function.
// Implementations must not retain the slices yielded by seq.
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
agg := fn()
if win, ok := agg.(WindowFunction); ok {
return win
}
return windowFunc{agg, name}
}))
}
rc := res_t(c.call("sqlite3_create_window_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateConstructor is a an [AggregateFunction] constructor.
type AggregateConstructor func() AggregateFunction
// AggregateFunction is the interface an aggregate function should implement.
//
// https://sqlite.org/appfunc.html
@@ -129,102 +175,145 @@ type WindowFunction interface {
func (c *Conn) OverloadFunction(name string, nArg int) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
r := c.call("sqlite3_overload_function",
uint64(c.handle), uint64(namePtr), uint64(nArg))
return c.error(r)
rc := res_t(c.call("sqlite3_overload_function",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg)))
return c.error(rc)
}
func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
func destroyCallback(ctx context.Context, mod api.Module, pApp ptr_t) {
util.DelHandle(ctx, pApp)
}
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTextRep uint32, zName ptr_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil {
name := util.ReadString(mod, zName, _MAX_NAME)
c.collation(c, name)
}
}
func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
}
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
callbackArgs(db, args[:nArg], pArg)
fn(Context{db, pCtx}, args[:nArg]...)
fn(Context{db, pCtx}, *args...)
}
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) {
db := ctx.Value(connKey{}).(*Conn)
callbackArgs(db, args[:nArg], pArg)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn, _ := callbackAggregate(db, pAgg, pApp)
fn.Step(Context{db, pCtx}, args[:nArg]...)
fn.Step(Context{db, pCtx}, *args...)
}
func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) {
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
db := ctx.Value(connKey{}).(*Conn)
fn, handle := callbackAggregate(db, pAgg, pApp)
fn.Value(Context{db, pCtx})
util.DelHandle(ctx, handle)
// Cleanup.
if final != 0 {
var err error
if handle != 0 {
err = util.DelHandle(ctx, handle)
} else if c, ok := fn.(io.Closer); ok {
err = c.Close()
}
if err != nil {
Context{db, pCtx}.ResultError(err)
return // notest
}
}
}
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) {
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {
db := ctx.Value(connKey{}).(*Conn)
fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction)
fn.Value(Context{db, pCtx})
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
callbackArgs(db, args[:nArg], pArg)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
fn.Inverse(Context{db, pCtx}, args[:nArg]...)
fn.Inverse(Context{db, pCtx}, *args...)
}
func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) {
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
if pApp == 0 {
handle := util.ReadUint32(db.mod, pAgg)
handle := util.Read32[ptr_t](db.mod, pAgg)
return util.GetHandle(db.ctx, handle).(AggregateFunction), handle
}
// We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.WriteUint32(db.mod, pAgg, handle)
util.Write32(db.mod, pAgg, handle)
return fn, handle
}
return fn, 0
}
func callbackArgs(db *Conn, arg []Value, pArg uint32) {
for i := range arg {
arg[i] = Value{
var (
valueArgsPool sync.Pool
valueArgsLen atomic.Int32
)
func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value {
arg, ok := valueArgsPool.Get().(*[]Value)
if !ok || cap(*arg) < int(nArg) {
max := valueArgsLen.Or(nArg) | nArg
lst := make([]Value, max)
arg = &lst
}
lst := (*arg)[:nArg]
for i := range lst {
lst[i] = Value{
c: db,
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen),
}
}
*arg = lst
return arg
}
var funcArgsPool sync.Pool
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
funcArgsPool.Put(p)
func returnArgs(p *[]Value) {
valueArgsPool.Put(p)
}
func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
if p := funcArgsPool.Get(); p == nil {
return new([_MAX_FUNCTION_ARG]Value)
} else {
return p.(*[_MAX_FUNCTION_ARG]Value)
type aggregateFunc struct {
ctx Context
arg []Value
next func() (struct{}, bool)
stop func()
}
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
a.ctx = ctx
a.arg = append(a.arg[:0], arg...)
if _, more := a.next(); !more {
a.stop()
}
}
func (a *aggregateFunc) Value(ctx Context) {
a.ctx = ctx
a.stop()
}
func (a *aggregateFunc) Close() error {
a.stop()
return nil
}
type windowFunc struct {
AggregateFunction
name string
}
func (w windowFunc) Inverse(ctx Context, arg ...Value) {
// Implementing inverse allows certain queries that don't really need it to succeed.
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
}

57
func_seq_test.go Normal file
View File

@@ -0,0 +1,57 @@
package sqlite3_test
import (
"fmt"
"iter"
"log"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateAggregateFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE test (col)`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`)
if err != nil {
log.Fatal(err)
}
err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS,
func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) {
count := 0
total := 0.0
for arg := range seq {
total += arg[0].Float()
count++
}
ctx.ResultFloat(total / float64(count))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnFloat(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 2
}

21
go.mod
View File

@@ -1,24 +1,27 @@
module github.com/ncruces/go-sqlite3
go 1.21
go 1.23.0
toolchain go1.23.0
toolchain go1.24.0
require (
github.com/ncruces/julianday v1.0.0
github.com/ncruces/sort v0.1.2
github.com/tetratelabs/wazero v1.8.2
golang.org/x/crypto v0.32.0
golang.org/x/sys v0.29.0
github.com/ncruces/sort v0.1.5
github.com/tetratelabs/wazero v1.9.0
golang.org/x/crypto v0.36.0
golang.org/x/sys v0.31.0
)
require (
github.com/dchest/siphash v1.2.3 // ext/bloom
github.com/google/uuid v1.6.0 // ext/uuid
github.com/psanford/httpreadat v0.1.0 // example
golang.org/x/sync v0.10.0 // test
golang.org/x/text v0.21.0 // ext/unicode
golang.org/x/sync v0.12.0 // test
golang.org/x/text v0.23.0 // ext/unicode
lukechampine.com/adiantum v1.1.1 // vfs/adiantum
)
retract v0.4.0 // tagged from the wrong branch
retract (
v0.23.2 // tagged from the wrong branch
v0.4.0 // tagged from the wrong branch
)

24
go.sum
View File

@@ -4,19 +4,19 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=
github.com/ncruces/sort v0.1.2/go.mod h1:vEJUTBJtebIuCMmXD18GKo5GJGhsay+xZFOoBEIXFmE=
github.com/ncruces/sort v0.1.5 h1:fiFWXXAqKI8QckPf/6hu/bGFwcEPrirIOFaJqWujs4k=
github.com/ncruces/sort v0.1.5/go.mod h1:obJToO4rYr6VWP0Uw5FYymgYGt3Br4RXcs/JdKaXAPk=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
lukechampine.com/adiantum v1.1.1 h1:4fp6gTxWCqpEbLy40ExiYDDED3oUNWx5cTqBCtPdZqA=
lukechampine.com/adiantum v1.1.1/go.mod h1:LrAYVnTYLnUtE/yMp5bQr0HstAf060YUF8nM0B6+rUw=

View File

@@ -1,7 +0,0 @@
go 1.21
use (
.
./gormlite
./embed/bcw2
)

View File

@@ -1,17 +0,0 @@
github.com/ncruces/go-sqlite3 v0.21.0/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=

View File

@@ -1,11 +1,11 @@
module github.com/ncruces/go-sqlite3/gormlite
go 1.21
go 1.23.0
toolchain go1.23.0
toolchain go1.24.0
require (
github.com/ncruces/go-sqlite3 v0.21.3
github.com/ncruces/go-sqlite3 v0.24.0
gorm.io/gorm v1.25.12
)
@@ -13,7 +13,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
github.com/tetratelabs/wazero v1.9.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
)

View File

@@ -2,15 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1CdMQ=
github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
github.com/ncruces/go-sqlite3 v0.24.0 h1:Z4jfmzu2NCd4SmyFwLT2OmF3EnTZbqwATvdiuNHNhLA=
github.com/ncruces/go-sqlite3 v0.24.0/go.mod h1:/Vs8ACZHjJ1SA6E9RZUn3EyB1OP3nDQ4z/ar+0fplTQ=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=

View File

@@ -7,18 +7,12 @@ import (
"github.com/tetratelabs/wazero"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// notest
func init() {
if util.CompilerSupported() {
sqlite3.RuntimeConfig = wazero.NewRuntimeConfigCompiler()
} else {
sqlite3.RuntimeConfig = wazero.NewRuntimeConfigInterpreter()
}
sqlite3.RuntimeConfig = sqlite3.RuntimeConfig.WithMemoryLimitPages(512)
sqlite3.RuntimeConfig = wazero.NewRuntimeConfig().WithMemoryLimitPages(512)
if os.Getenv("CI") != "" {
path := filepath.Join(os.TempDir(), "wazero")
if err := os.MkdirAll(path, 0777); err == nil {

View File

@@ -1,27 +0,0 @@
package util
import (
"runtime"
"golang.org/x/sys/cpu"
)
func CompilerSupported() bool {
switch runtime.GOOS {
case "linux", "android",
"windows", "darwin",
"freebsd", "netbsd", "dragonfly",
"solaris", "illumos":
break
default:
return false
}
switch runtime.GOARCH {
case "amd64":
return cpu.X86.HasSSE41
case "arm64":
return true
default:
return false
}
}

View File

@@ -7,9 +7,6 @@ import (
"github.com/tetratelabs/wazero/api"
)
type i32 interface{ ~int32 | ~uint32 }
type i64 interface{ ~int64 | ~uint64 }
type funcVI[T0 i32] func(context.Context, api.Module, T0)
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {

View File

@@ -20,7 +20,7 @@ func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
s.holes = 0
}
func GetHandle(ctx context.Context, id uint32) any {
func GetHandle(ctx context.Context, id Ptr_t) any {
if id == 0 {
return nil
}
@@ -28,14 +28,14 @@ func GetHandle(ctx context.Context, id uint32) any {
return s.handles[^id]
}
func DelHandle(ctx context.Context, id uint32) error {
func DelHandle(ctx context.Context, id Ptr_t) error {
if id == 0 {
return nil
}
s := ctx.Value(moduleKey{}).(*moduleState)
a := s.handles[^id]
s.handles[^id] = nil
if l := uint32(len(s.handles)); l == ^id {
if l := Ptr_t(len(s.handles)); l == ^id {
s.handles = s.handles[:l-1]
} else {
s.holes++
@@ -46,7 +46,7 @@ func DelHandle(ctx context.Context, id uint32) error {
return nil
}
func AddHandle(ctx context.Context, a any) uint32 {
func AddHandle(ctx context.Context, a any) Ptr_t {
if a == nil {
panic(NilErr)
}
@@ -59,12 +59,12 @@ func AddHandle(ctx context.Context, a any) uint32 {
if h == nil {
s.holes--
s.handles[id] = a
return ^uint32(id)
return ^Ptr_t(id)
}
}
}
// Add a new slot.
s.handles = append(s.handles, a)
return -uint32(len(s.handles))
return -Ptr_t(len(s.handles))
}

View File

@@ -2,6 +2,7 @@ package util
import (
"encoding/json"
"math"
"strconv"
"time"
"unsafe"
@@ -20,7 +21,7 @@ func (j JSON) Scan(value any) error {
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
buf = AppendNumber(nil, v)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
@@ -33,3 +34,17 @@ func (j JSON) Scan(value any) error {
return json.Unmarshal(buf, j.Value)
}
func AppendNumber(dst []byte, f float64) []byte {
switch {
case math.IsNaN(f):
dst = append(dst, "null"...)
case math.IsInf(f, 1):
dst = append(dst, "9.0e999"...)
case math.IsInf(f, -1):
dst = append(dst, "-9.0e999"...)
default:
return strconv.AppendFloat(dst, f, 'g', -1, 64)
}
return dst
}

View File

@@ -7,110 +7,130 @@ import (
"github.com/tetratelabs/wazero/api"
)
func View(mod api.Module, ptr uint32, size uint64) []byte {
const (
PtrLen = 4
IntLen = 4
)
type (
i8 interface{ ~int8 | ~uint8 }
i32 interface{ ~int32 | ~uint32 }
i64 interface{ ~int64 | ~uint64 }
Stk_t = uint64
Ptr_t uint32
Res_t int32
)
func View(mod api.Module, ptr Ptr_t, size int64) []byte {
if ptr == 0 {
panic(NilErr)
}
if size > math.MaxUint32 {
if uint64(size) > math.MaxUint32 {
panic(RangeErr)
}
if size == 0 {
return nil
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
buf, ok := mod.Memory().Read(uint32(ptr), uint32(size))
if !ok {
panic(RangeErr)
}
return buf
}
func ReadUint8(mod api.Module, ptr uint32) uint8 {
func Read[T i8](mod api.Module, ptr Ptr_t) T {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadByte(ptr)
v, ok := mod.Memory().ReadByte(uint32(ptr))
if !ok {
panic(RangeErr)
}
return v
return T(v)
}
func ReadUint32(mod api.Module, ptr uint32) uint32 {
func Write[T i8](mod api.Module, ptr Ptr_t, v T) {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint8(mod api.Module, ptr uint32, v uint8) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteByte(ptr, v)
ok := mod.Memory().WriteByte(uint32(ptr), uint8(v))
if !ok {
panic(RangeErr)
}
}
func WriteUint32(mod api.Module, ptr uint32, v uint32) {
func Read32[T i32](mod api.Module, ptr Ptr_t) T {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(ptr, v)
v, ok := mod.Memory().ReadUint32Le(uint32(ptr))
if !ok {
panic(RangeErr)
}
return T(v)
}
func Write32[T i32](mod api.Module, ptr Ptr_t, v T) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(uint32(ptr), uint32(v))
if !ok {
panic(RangeErr)
}
}
func ReadUint64(mod api.Module, ptr uint32) uint64 {
func Read64[T i64](mod api.Module, ptr Ptr_t) T {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint64Le(ptr)
v, ok := mod.Memory().ReadUint64Le(uint32(ptr))
if !ok {
panic(RangeErr)
}
return v
return T(v)
}
func WriteUint64(mod api.Module, ptr uint32, v uint64) {
func Write64[T i64](mod api.Module, ptr Ptr_t, v T) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint64Le(ptr, v)
ok := mod.Memory().WriteUint64Le(uint32(ptr), uint64(v))
if !ok {
panic(RangeErr)
}
}
func ReadFloat64(mod api.Module, ptr uint32) float64 {
return math.Float64frombits(ReadUint64(mod, ptr))
func ReadFloat64(mod api.Module, ptr Ptr_t) float64 {
return math.Float64frombits(Read64[uint64](mod, ptr))
}
func WriteFloat64(mod api.Module, ptr uint32, v float64) {
WriteUint64(mod, ptr, math.Float64bits(v))
func WriteFloat64(mod api.Module, ptr Ptr_t, v float64) {
Write64(mod, ptr, math.Float64bits(v))
}
func ReadString(mod api.Module, ptr, maxlen uint32) string {
func ReadBool(mod api.Module, ptr Ptr_t) bool {
return Read32[int32](mod, ptr) != 0
}
func WriteBool(mod api.Module, ptr Ptr_t, v bool) {
var i int32
if v {
i = 1
}
Write32(mod, ptr, i)
}
func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string {
if ptr == 0 {
panic(NilErr)
}
switch maxlen {
case 0:
if maxlen <= 0 {
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
maxlen = min(maxlen, math.MaxInt32-1) + 1
buf, ok := mem.Read(uint32(ptr), uint32(maxlen))
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
buf, ok = mem.Read(uint32(ptr), mem.Size()-uint32(ptr))
if !ok {
panic(RangeErr)
}
@@ -122,13 +142,13 @@ func ReadString(mod api.Module, ptr, maxlen uint32) string {
}
}
func WriteBytes(mod api.Module, ptr uint32, b []byte) {
buf := View(mod, ptr, uint64(len(b)))
func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) {
buf := View(mod, ptr, int64(len(b)))
copy(buf, b)
}
func WriteString(mod api.Module, ptr uint32, s string) {
buf := View(mod, ptr, uint64(len(s)+1))
func WriteString(mod api.Module, ptr Ptr_t, s string) {
buf := View(mod, ptr, int64(len(s))+1)
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -31,90 +31,90 @@ func TestView_overflow(t *testing.T) {
func TestReadUint8_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint8(mock, 0)
Read[byte](mock, 0)
t.Error("want panic")
}
func TestReadUint8_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint8(mock, wazerotest.PageSize)
Read[byte](mock, wazerotest.PageSize)
t.Error("want panic")
}
func TestReadUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint32(mock, 0)
Read32[uint32](mock, 0)
t.Error("want panic")
}
func TestReadUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint32(mock, wazerotest.PageSize-2)
Read32[uint32](mock, wazerotest.PageSize-2)
t.Error("want panic")
}
func TestReadUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint64(mock, 0)
Read64[uint64](mock, 0)
t.Error("want panic")
}
func TestReadUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint64(mock, wazerotest.PageSize-2)
Read64[uint64](mock, wazerotest.PageSize-2)
t.Error("want panic")
}
func TestWriteUint8_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint8(mock, 0, 1)
Write[byte](mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint8_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint8(mock, wazerotest.PageSize, 1)
Write[byte](mock, wazerotest.PageSize, 1)
t.Error("want panic")
}
func TestWriteUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint32(mock, 0, 1)
Write32[uint32](mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint32(mock, wazerotest.PageSize-2, 1)
Write32[uint32](mock, wazerotest.PageSize-2, 1)
t.Error("want panic")
}
func TestWriteUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint64(mock, 0, 1)
Write64[uint64](mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint64(mock, wazerotest.PageSize-2, 1)
Write64[uint64](mock, wazerotest.PageSize-2, 1)
t.Error("want panic")
}
func TestReadString_range(t *testing.T) {
defer func() { _ = recover() }()
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadString(mock, wazerotest.PageSize+2, math.MaxUint32)
ReadString(mock, wazerotest.PageSize+2, math.MaxInt)
t.Error("want panic")
}

View File

@@ -25,9 +25,9 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped
// Allocate page aligned memmory.
alloc := mod.ExportedFunction("aligned_alloc")
stack := [...]uint64{
uint64(unix.Getpagesize()),
uint64(size),
stack := [...]Stk_t{
Stk_t(unix.Getpagesize()),
Stk_t(size),
}
if err := alloc.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
@@ -37,20 +37,20 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped
}
// Save the newly allocated region.
ptr := uint32(stack[0])
buf := View(mod, ptr, uint64(size))
res := &MappedRegion{
ptr := Ptr_t(stack[0])
buf := View(mod, ptr, int64(size))
ret := &MappedRegion{
Ptr: ptr,
size: size,
addr: unsafe.Pointer(&buf[0]),
}
s.regions = append(s.regions, res)
return res
s.regions = append(s.regions, ret)
return ret
}
type MappedRegion struct {
addr unsafe.Pointer
Ptr uint32
Ptr Ptr_t
size int32
used bool
}

View File

@@ -29,13 +29,13 @@ func MapRegion(ctx context.Context, mod api.Module, f *os.File, offset int64, si
return nil, err
}
res := &MappedRegion{Handle: h, addr: a}
ret := &MappedRegion{Handle: h, addr: a}
// SliceHeader, although deprecated, avoids a go vet warning.
sh := (*reflect.SliceHeader)(unsafe.Pointer(&res.Data))
sh := (*reflect.SliceHeader)(unsafe.Pointer(&ret.Data))
sh.Len = int(size)
sh.Cap = int(size)
sh.Data = a
return res, nil
return ret, nil
}
func (r *MappedRegion) Unmap() error {

View File

@@ -8,6 +8,8 @@ import (
"github.com/ncruces/go-sqlite3/internal/alloc"
)
type ConnKey struct{}
type moduleKey struct{}
type moduleState struct {
mmapState

View File

@@ -3,7 +3,6 @@ package sqlite3
import (
"context"
"math"
"math/bits"
"os"
"sync"
@@ -48,18 +47,14 @@ func compileSQLite() {
ctx := context.Background()
cfg := RuntimeConfig
if cfg == nil {
if util.CompilerSupported() {
cfg = wazero.NewRuntimeConfigCompiler()
} else {
cfg = wazero.NewRuntimeConfigInterpreter()
}
cfg = wazero.NewRuntimeConfig()
if bits.UintSize < 64 {
cfg = cfg.WithMemoryLimitPages(512) // 32MB
} else {
cfg = cfg.WithMemoryLimitPages(4096) // 256MB
}
cfg = cfg.WithCoreFeatures(api.CoreFeaturesV2)
}
cfg = cfg.WithCoreFeatures(api.CoreFeaturesV2)
instance.runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
@@ -94,7 +89,7 @@ type sqlite struct {
id [32]*byte
mask uint32
}
stack [9]uint64
stack [9]stk_t
}
func instantiateSQLite() (sqlt *sqlite, err error) {
@@ -120,7 +115,7 @@ func (sqlt *sqlite) close() error {
return sqlt.mod.Close(sqlt.ctx)
}
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
if rc == _OK {
return nil
}
@@ -131,18 +126,18 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
if r := sqlt.call("sqlite3_errstr", rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 {
err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME)
}
if handle != 0 {
if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH)
if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
}
if len(sql) != 0 {
if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
err.sql = sql[0][i:]
}
}
}
@@ -182,7 +177,7 @@ func (sqlt *sqlite) putfn(name string, fn api.Function) {
}
}
func (sqlt *sqlite) call(name string, params ...uint64) uint64 {
func (sqlt *sqlite) call(name string, params ...stk_t) stk_t {
copy(sqlt.stack[:], params)
fn := sqlt.getfn(name)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
@@ -190,33 +185,33 @@ func (sqlt *sqlite) call(name string, params ...uint64) uint64 {
panic(err)
}
sqlt.putfn(name, fn)
return sqlt.stack[0]
return stk_t(sqlt.stack[0])
}
func (sqlt *sqlite) free(ptr uint32) {
func (sqlt *sqlite) free(ptr ptr_t) {
if ptr == 0 {
return
}
sqlt.call("sqlite3_free", uint64(ptr))
sqlt.call("sqlite3_free", stk_t(ptr))
}
func (sqlt *sqlite) new(size uint64) uint32 {
ptr := uint32(sqlt.call("sqlite3_malloc64", size))
func (sqlt *sqlite) new(size int64) ptr_t {
ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size)))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (sqlt *sqlite) realloc(ptr uint32, size uint64) uint32 {
ptr = uint32(sqlt.call("sqlite3_realloc64", uint64(ptr), size))
func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size)))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (sqlt *sqlite) newBytes(b []byte) uint32 {
func (sqlt *sqlite) newBytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
return 0
}
@@ -224,33 +219,31 @@ func (sqlt *sqlite) newBytes(b []byte) uint32 {
if size == 0 {
size = 1
}
ptr := sqlt.new(uint64(size))
ptr := sqlt.new(int64(size))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
func (sqlt *sqlite) newString(s string) uint32 {
ptr := sqlt.new(uint64(len(s) + 1))
func (sqlt *sqlite) newString(s string) ptr_t {
ptr := sqlt.new(int64(len(s)) + 1)
util.WriteString(sqlt.mod, ptr, s)
return ptr
}
func (sqlt *sqlite) newArena(size uint64) arena {
// Ensure the arena's size is a multiple of 8.
size = (size + 7) &^ 7
const arenaSize = 4096
func (sqlt *sqlite) newArena() arena {
return arena{
sqlt: sqlt,
size: uint32(size),
base: sqlt.new(size),
base: sqlt.new(arenaSize),
}
}
type arena struct {
sqlt *sqlite
ptrs []uint32
base uint32
next uint32
size uint32
ptrs []ptr_t
base ptr_t
next int32
}
func (a *arena) free() {
@@ -277,34 +270,34 @@ func (a *arena) mark() (reset func()) {
}
}
func (a *arena) new(size uint64) uint32 {
func (a *arena) new(size int64) ptr_t {
// Align the next address, to 4 or 8 bytes.
if size&7 != 0 {
a.next = (a.next + 3) &^ 3
} else {
a.next = (a.next + 7) &^ 7
}
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
if size <= arenaSize-int64(a.next) {
ptr := a.base + ptr_t(a.next)
a.next += int32(size)
return ptr_t(ptr)
}
ptr := a.sqlt.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
return ptr_t(ptr)
}
func (a *arena) bytes(b []byte) uint32 {
func (a *arena) bytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
return 0
}
ptr := a.new(uint64(len(b)))
ptr := a.new(int64(len(b)))
util.WriteBytes(a.sqlt.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
func (a *arena) string(s string) ptr_t {
ptr := a.new(int64(len(s)) + 1)
util.WriteString(a.sqlt.mod, ptr, s)
return ptr
}
@@ -324,8 +317,7 @@ func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", destroyCallback)
util.ExportFuncVIIII(env, "go_func", funcCallback)
util.ExportFuncVIIIII(env, "go_step", stepCallback)
util.ExportFuncVIII(env, "go_final", finalCallback)
util.ExportFuncVII(env, "go_value", valueCallback)
util.ExportFuncVIIII(env, "go_value", valueCallback)
util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
util.ExportFuncIIIIII(env, "go_compare", compareCallback)

View File

@@ -2,7 +2,7 @@
# handle, and interrupt, sqlite3_busy_timeout.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -182928,7 +182928,7 @@
@@ -183355,7 +183355,7 @@
if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT;
#endif
if( ms>0 ){

View File

@@ -3,7 +3,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://sqlite.org/2024/sqlite-amalgamation-3470200.zip"
curl -#OL "https://sqlite.org/2025/sqlite-amalgamation-3490100.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3.c .
mv sqlite-amalgamation-*/sqlite3.h .
@@ -17,32 +17,32 @@ rm -rf sqlite-amalgamation-*
# mv sqlite-snapshot-*/sqlite3ext.h .
# rm -rf sqlite-snapshot-*
cat *.patch | patch --no-backup-if-mismatch
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/ieee754.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/spellfix.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/ieee754.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/spellfix.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/ext/misc/uint.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/mptest/multiwrite01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/mptest/wasm/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/mptest/mptest.c"
cd ~-
cd ../vfs/tests/speedtest1/wasm/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.47.2/test/speedtest1.c"
cd ~-
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.49.1/test/speedtest1.c"
cd ~-
cat *.patch | patch -p0 --no-backup-if-mismatch

View File

@@ -11,9 +11,8 @@ int go_compare(go_handle, int, const void *, int, const void *);
void go_func(sqlite3_context *, go_handle, int, sqlite3_value **);
void go_step(sqlite3_context *, go_handle *, go_handle, int, sqlite3_value **);
void go_final(sqlite3_context *, go_handle, go_handle);
void go_value(sqlite3_context *, go_handle);
void go_inverse(sqlite3_context *, go_handle *, int, sqlite3_value **);
void go_value(sqlite3_context *, go_handle *, go_handle, bool);
void go_inverse(sqlite3_context *, go_handle, int, sqlite3_value **);
void go_func_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_func(ctx, sqlite3_user_data(ctx), nArg, pArg);
@@ -28,22 +27,26 @@ void go_step_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_step(ctx, agg, data, nArg, pArg);
}
void go_value_wrapper(sqlite3_context *ctx) {
go_handle *agg = sqlite3_aggregate_context(ctx, 4);
go_handle data = NULL;
if (agg == NULL || *agg == NULL) {
data = sqlite3_user_data(ctx);
}
go_value(ctx, agg, data, /*final=*/false);
}
void go_final_wrapper(sqlite3_context *ctx) {
go_handle *agg = sqlite3_aggregate_context(ctx, 0);
go_handle data = NULL;
if (agg == NULL || *agg == NULL) {
data = sqlite3_user_data(ctx);
}
go_final(ctx, agg, data);
}
void go_value_wrapper(sqlite3_context *ctx) {
go_handle *agg = sqlite3_aggregate_context(ctx, 4);
go_value(ctx, *agg);
go_value(ctx, agg, data, /*final=*/true);
}
void go_inverse_wrapper(sqlite3_context *ctx, int nArg, sqlite3_value **pArg) {
go_handle *agg = sqlite3_aggregate_context(ctx, 4);
go_handle *agg = sqlite3_aggregate_context(ctx, 0);
go_inverse(ctx, *agg, nArg, pArg);
}

View File

@@ -19,6 +19,10 @@ void go_log(void *, int, const char *);
unsigned int go_autovacuum_pages(void *, const char *, unsigned int,
unsigned int, unsigned int);
void sqlite3_log_go(int iErrCode, const char *zMsg) {
sqlite3_log(iErrCode, "%s", zMsg);
}
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/NULL);
}
@@ -65,8 +69,8 @@ int sqlite3_autovacuum_pages_go(sqlite3 *db, go_handle app) {
#ifndef sqliteBusyCallback
static int sqliteBusyCallback(sqlite3 *db, int count) {
return go_busy_timeout(count, db->busyTimeout);
static int sqliteBusyCallback(void *ptr, int count) {
return go_busy_timeout(count, ((sqlite3 *)ptr)->busyTimeout);
}
#endif

View File

@@ -18,8 +18,6 @@
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define LONGDOUBLE_TYPE double
#define HAVE_LOG2 1
#define HAVE_LOG10 1
#define HAVE_ISNAN 1
@@ -35,14 +33,8 @@
#define HAVE_MALLOC_H 1
#define HAVE_MALLOC_USABLE_SIZE 1
// Because Wasm does not support shared memory,
// SQLite disables WAL for Wasm builds.
#undef SQLITE_OMIT_WAL
// Implemented in hooks.c.
static int sqliteBusyCallback(void *, int);
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);
// Implemented in hooks.c.
#ifndef sqliteBusyCallback
static int sqliteBusyCallback(sqlite3 *, int);
#endif
int localtime_s(struct tm *const pTm, time_t const *const pTime);

View File

@@ -90,6 +90,7 @@ struct go_file {
};
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
// The default VFS.
if (!zVfsName || !strcmp(zVfsName, "os")) {
static sqlite3_vfs os_vfs = {
.iVersion = 2,
@@ -109,18 +110,21 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
return &os_vfs;
}
// Check if a Go VFS exists.
if (!go_vfs_find(zVfsName)) {
return NULL;
}
static sqlite3_vfs *go_vfs_list;
// Do we already have a C wrapper for the Go VFS?
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
if (!strcmp(zVfsName, it->zName)) {
return it;
}
}
// Delete C wrappers that are no longer needed.
for (sqlite3_vfs **ptr = &go_vfs_list; *ptr;) {
sqlite3_vfs *it = *ptr;
if (go_vfs_find(it->zName)) {
@@ -131,6 +135,7 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
}
}
// Create a new C wrapper.
sqlite3_vfs *head = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
@@ -158,8 +163,11 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime(pTm, (sqlite3_int64)*pTime);
}
int sqlite3_os_init() {
return SQLITE_OK;
int sqlite3_os_init() { return SQLITE_OK; }
int sqlite3_invoke_busy_handler_go(sqlite3_int64 token) {
void **ap = (void **)&token;
return ((int (*)(void *))(ap[0]))(ap[1]);
}
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");

View File

@@ -1,7 +1,7 @@
# Remove VFS registration. Go handles it.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -26603,7 +26603,7 @@
@@ -26725,7 +26725,7 @@
sqlite3_free(p);
return sqlite3_os_init();
}
@@ -10,7 +10,7 @@
/*
** The list of all registered VFS implementations.
*/
@@ -26700,7 +26700,7 @@
@@ -26822,7 +26822,7 @@
sqlite3_mutex_leave(mutex);
return SQLITE_OK;
}

View File

@@ -22,7 +22,7 @@ func Test_sqlite_error_OOM(t *testing.T) {
defer sqlite.close()
defer func() { _ = recover() }()
sqlite.error(uint64(NOMEM), 0)
sqlite.error(res_t(NOMEM), 0)
t.Error("want panic")
}
@@ -65,7 +65,7 @@ func Test_sqlite_newArena(t *testing.T) {
}
defer sqlite.close()
arena := sqlite.newArena(16)
arena := sqlite.newArena()
defer arena.free()
const title = "Lorem ipsum"
@@ -73,7 +73,7 @@ func Test_sqlite_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -82,7 +82,7 @@ func Test_sqlite_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != body {
t.Errorf("got %q, want %q", got, body)
}
@@ -94,7 +94,7 @@ func Test_sqlite_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
if got := util.View(sqlite.mod, ptr, int64(len(title))); string(got) != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -122,7 +122,7 @@ func Test_sqlite_newBytes(t *testing.T) {
}
want := buf
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(sqlite.mod, ptr, int64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
@@ -130,10 +130,6 @@ func Test_sqlite_newBytes(t *testing.T) {
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
if got := util.View(sqlite.mod, ptr, 0); got != nil {
t.Errorf("got %q, want nil", got)
}
}
func Test_sqlite_newString(t *testing.T) {
@@ -157,7 +153,7 @@ func Test_sqlite_newString(t *testing.T) {
}
want := str + "\000"
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
if got := util.View(sqlite.mod, ptr, int64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -183,7 +179,7 @@ func Test_sqlite_getString(t *testing.T) {
}
want := "sqlite3"
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
@@ -192,13 +188,13 @@ func Test_sqlite_getString(t *testing.T) {
func() {
defer func() { _ = recover() }()
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
util.ReadString(sqlite.mod, ptr, int64(len(want))/2)
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
util.ReadString(sqlite.mod, 0, math.MaxUint32)
util.ReadString(sqlite.mod, 0, math.MaxInt)
t.Error("want panic")
}()
}

256
stmt.go
View File

@@ -16,7 +16,7 @@ type Stmt struct {
c *Conn
err error
sql string
handle uint32
handle ptr_t
}
// Close destroys the prepared statement object.
@@ -29,7 +29,7 @@ func (s *Stmt) Close() error {
return nil
}
r := s.c.call("sqlite3_finalize", uint64(s.handle))
rc := res_t(s.c.call("sqlite3_finalize", stk_t(s.handle)))
stmts := s.c.stmts
for i := range stmts {
if s == stmts[i] {
@@ -42,7 +42,7 @@ func (s *Stmt) Close() error {
}
s.handle = 0
return s.c.error(r)
return s.c.error(rc)
}
// Conn returns the database connection to which the prepared statement belongs.
@@ -64,9 +64,9 @@ func (s *Stmt) SQL() string {
//
// https://sqlite.org/c3ref/expanded_sql.html
func (s *Stmt) ExpandedSQL() string {
r := s.c.call("sqlite3_expanded_sql", uint64(s.handle))
sql := util.ReadString(s.c.mod, uint32(r), _MAX_SQL_LENGTH)
s.c.free(uint32(r))
ptr := ptr_t(s.c.call("sqlite3_expanded_sql", stk_t(s.handle)))
sql := util.ReadString(s.c.mod, ptr, _MAX_SQL_LENGTH)
s.c.free(ptr)
return sql
}
@@ -75,25 +75,25 @@ func (s *Stmt) ExpandedSQL() string {
//
// https://sqlite.org/c3ref/stmt_readonly.html
func (s *Stmt) ReadOnly() bool {
r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle))
return r != 0
b := int32(s.c.call("sqlite3_stmt_readonly", stk_t(s.handle)))
return b != 0
}
// Reset resets the prepared statement object.
//
// https://sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
r := s.c.call("sqlite3_reset", uint64(s.handle))
rc := res_t(s.c.call("sqlite3_reset", stk_t(s.handle)))
s.err = nil
return s.c.error(r)
return s.c.error(rc)
}
// Busy determines if a prepared statement has been reset.
//
// https://sqlite.org/c3ref/stmt_busy.html
func (s *Stmt) Busy() bool {
r := s.c.call("sqlite3_stmt_busy", uint64(s.handle))
return r != 0
rc := res_t(s.c.call("sqlite3_stmt_busy", stk_t(s.handle)))
return rc != 0
}
// Step evaluates the SQL statement.
@@ -107,15 +107,15 @@ func (s *Stmt) Busy() bool {
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt(s.c.handle)
r := s.c.call("sqlite3_step", uint64(s.handle))
switch r {
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:
s.err = nil
return true
case _DONE:
s.err = nil
default:
s.err = s.c.error(r)
s.err = s.c.error(rc)
}
return false
}
@@ -143,30 +143,30 @@ func (s *Stmt) Status(op StmtStatus, reset bool) int {
if op > STMTSTATUS_FILTER_HIT && op != STMTSTATUS_MEMUSED {
return 0
}
var i uint64
var i int32
if reset {
i = 1
}
r := s.c.call("sqlite3_stmt_status", uint64(s.handle),
uint64(op), i)
return int(int32(r))
n := int32(s.c.call("sqlite3_stmt_status", stk_t(s.handle),
stk_t(op), stk_t(i)))
return int(n)
}
// ClearBindings resets all bindings on the prepared statement.
//
// https://sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call("sqlite3_clear_bindings", uint64(s.handle))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_clear_bindings", stk_t(s.handle)))
return s.c.error(rc)
}
// BindCount returns the number of SQL parameters in the prepared statement.
//
// https://sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r := s.c.call("sqlite3_bind_parameter_count",
uint64(s.handle))
return int(int32(r))
n := int32(s.c.call("sqlite3_bind_parameter_count",
stk_t(s.handle)))
return int(n)
}
// BindIndex returns the index of a parameter in the prepared statement
@@ -176,9 +176,9 @@ func (s *Stmt) BindCount() int {
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.mark()()
namePtr := s.c.arena.string(name)
r := s.c.call("sqlite3_bind_parameter_index",
uint64(s.handle), uint64(namePtr))
return int(int32(r))
i := int32(s.c.call("sqlite3_bind_parameter_index",
stk_t(s.handle), stk_t(namePtr)))
return int(i)
}
// BindName returns the name of a parameter in the prepared statement.
@@ -186,10 +186,8 @@ func (s *Stmt) BindIndex(name string) int {
//
// https://sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
r := s.c.call("sqlite3_bind_parameter_name",
uint64(s.handle), uint64(param))
ptr := uint32(r)
ptr := ptr_t(s.c.call("sqlite3_bind_parameter_name",
stk_t(s.handle), stk_t(param)))
if ptr == 0 {
return ""
}
@@ -223,9 +221,9 @@ func (s *Stmt) BindInt(param int, value int) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
r := s.c.call("sqlite3_bind_int64",
uint64(s.handle), uint64(param), uint64(value))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_int64",
stk_t(s.handle), stk_t(param), stk_t(value)))
return s.c.error(rc)
}
// BindFloat binds a float64 to the prepared statement.
@@ -233,9 +231,10 @@ func (s *Stmt) BindInt64(param int, value int64) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
r := s.c.call("sqlite3_bind_double",
uint64(s.handle), uint64(param), math.Float64bits(value))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_double",
stk_t(s.handle), stk_t(param),
stk_t(math.Float64bits(value))))
return s.c.error(rc)
}
// BindText binds a string to the prepared statement.
@@ -247,10 +246,10 @@ func (s *Stmt) BindText(param int, value string) error {
return TOOBIG
}
ptr := s.c.newString(value)
r := s.c.call("sqlite3_bind_text_go",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(value))))
return s.c.error(rc)
}
// BindRawText binds a []byte to the prepared statement as text.
@@ -263,10 +262,10 @@ func (s *Stmt) BindRawText(param int, value []byte) error {
return TOOBIG
}
ptr := s.c.newBytes(value)
r := s.c.call("sqlite3_bind_text_go",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(value))))
return s.c.error(rc)
}
// BindBlob binds a []byte to the prepared statement.
@@ -279,10 +278,10 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
return TOOBIG
}
ptr := s.c.newBytes(value)
r := s.c.call("sqlite3_bind_blob_go",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_blob_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(value))))
return s.c.error(rc)
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
@@ -290,9 +289,9 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r := s.c.call("sqlite3_bind_zeroblob64",
uint64(s.handle), uint64(param), uint64(n))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_zeroblob64",
stk_t(s.handle), stk_t(param), stk_t(n)))
return s.c.error(rc)
}
// BindNull binds a NULL to the prepared statement.
@@ -300,9 +299,9 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
r := s.c.call("sqlite3_bind_null",
uint64(s.handle), uint64(param))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_null",
stk_t(s.handle), stk_t(param)))
return s.c.error(rc)
}
// BindTime binds a [time.Time] to the prepared statement.
@@ -316,28 +315,27 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
}
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
return s.BindText(param, v)
case int64:
s.BindInt64(param, v)
return s.BindInt64(param, v)
case float64:
s.BindFloat(param, v)
return s.BindFloat(param, v)
default:
panic(util.AssertErr())
}
return nil
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano)) + 5
const maxlen = int64(len(time.RFC3339Nano)) + 5
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
r := s.c.call("sqlite3_bind_text_go",
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(buf))))
return s.c.error(rc)
}
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
@@ -348,9 +346,9 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
r := s.c.call("sqlite3_bind_pointer_go",
uint64(s.handle), uint64(param), uint64(valPtr))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_pointer_go",
stk_t(s.handle), stk_t(param), stk_t(valPtr)))
return s.c.error(rc)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
@@ -373,27 +371,27 @@ func (s *Stmt) BindValue(param int, value Value) error {
if value.c != s.c {
return MISUSE
}
r := s.c.call("sqlite3_bind_value",
uint64(s.handle), uint64(param), uint64(value.handle))
return s.c.error(r)
rc := res_t(s.c.call("sqlite3_bind_value",
stk_t(s.handle), stk_t(param), stk_t(value.handle)))
return s.c.error(rc)
}
// DataCount resets the number of columns in a result set.
//
// https://sqlite.org/c3ref/data_count.html
func (s *Stmt) DataCount() int {
r := s.c.call("sqlite3_data_count",
uint64(s.handle))
return int(int32(r))
n := int32(s.c.call("sqlite3_data_count",
stk_t(s.handle)))
return int(n)
}
// ColumnCount returns the number of columns in a result set.
//
// https://sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r := s.c.call("sqlite3_column_count",
uint64(s.handle))
return int(int32(r))
n := int32(s.c.call("sqlite3_column_count",
stk_t(s.handle)))
return int(n)
}
// ColumnName returns the name of the result column.
@@ -401,12 +399,12 @@ func (s *Stmt) ColumnCount() int {
//
// https://sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r := s.c.call("sqlite3_column_name",
uint64(s.handle), uint64(col))
if r == 0 {
ptr := ptr_t(s.c.call("sqlite3_column_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
panic(util.OOMErr)
}
return util.ReadString(s.c.mod, uint32(r), _MAX_NAME)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnType returns the initial [Datatype] of the result column.
@@ -414,9 +412,8 @@ func (s *Stmt) ColumnName(col int) string {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
r := s.c.call("sqlite3_column_type",
uint64(s.handle), uint64(col))
return Datatype(r)
return Datatype(s.c.call("sqlite3_column_type",
stk_t(s.handle), stk_t(col)))
}
// ColumnDeclType returns the declared datatype of the result column.
@@ -424,12 +421,12 @@ func (s *Stmt) ColumnType(col int) Datatype {
//
// https://sqlite.org/c3ref/column_decltype.html
func (s *Stmt) ColumnDeclType(col int) string {
r := s.c.call("sqlite3_column_decltype",
uint64(s.handle), uint64(col))
if r == 0 {
ptr := ptr_t(s.c.call("sqlite3_column_decltype",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, uint32(r), _MAX_NAME)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnDatabaseName returns the name of the database
@@ -438,12 +435,12 @@ func (s *Stmt) ColumnDeclType(col int) string {
//
// https://sqlite.org/c3ref/column_database_name.html
func (s *Stmt) ColumnDatabaseName(col int) string {
r := s.c.call("sqlite3_column_database_name",
uint64(s.handle), uint64(col))
if r == 0 {
ptr := ptr_t(s.c.call("sqlite3_column_database_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, uint32(r), _MAX_NAME)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnTableName returns the name of the table
@@ -452,12 +449,12 @@ func (s *Stmt) ColumnDatabaseName(col int) string {
//
// https://sqlite.org/c3ref/column_database_name.html
func (s *Stmt) ColumnTableName(col int) string {
r := s.c.call("sqlite3_column_table_name",
uint64(s.handle), uint64(col))
if r == 0 {
ptr := ptr_t(s.c.call("sqlite3_column_table_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, uint32(r), _MAX_NAME)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnOriginName returns the name of the table column
@@ -466,12 +463,12 @@ func (s *Stmt) ColumnTableName(col int) string {
//
// https://sqlite.org/c3ref/column_database_name.html
func (s *Stmt) ColumnOriginName(col int) string {
r := s.c.call("sqlite3_column_origin_name",
uint64(s.handle), uint64(col))
if r == 0 {
ptr := ptr_t(s.c.call("sqlite3_column_origin_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, uint32(r), _MAX_NAME)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnBool returns the value of the result column as a bool.
@@ -498,9 +495,8 @@ func (s *Stmt) ColumnInt(col int) int {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
r := s.c.call("sqlite3_column_int64",
uint64(s.handle), uint64(col))
return int64(r)
return int64(s.c.call("sqlite3_column_int64",
stk_t(s.handle), stk_t(col)))
}
// ColumnFloat returns the value of the result column as a float64.
@@ -508,9 +504,9 @@ func (s *Stmt) ColumnInt64(col int) int64 {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
r := s.c.call("sqlite3_column_double",
uint64(s.handle), uint64(col))
return math.Float64frombits(r)
f := uint64(s.c.call("sqlite3_column_double",
stk_t(s.handle), stk_t(col)))
return math.Float64frombits(f)
}
// ColumnTime returns the value of the result column as a [time.Time].
@@ -562,9 +558,9 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call("sqlite3_column_text",
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
ptr := ptr_t(s.c.call("sqlite3_column_text",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -574,23 +570,23 @@ func (s *Stmt) ColumnRawText(col int) []byte {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call("sqlite3_column_blob",
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
ptr := ptr_t(s.c.call("sqlite3_column_blob",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
}
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
if ptr == 0 {
r := s.c.call("sqlite3_errcode", uint64(s.c.handle))
if r != _ROW && r != _DONE {
s.err = s.c.error(r)
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
if rc != _ROW && rc != _DONE {
s.err = s.c.error(rc)
}
return nil
}
r := s.c.call("sqlite3_column_bytes",
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
n := int32(s.c.call("sqlite3_column_bytes",
stk_t(s.handle), stk_t(col)))
return util.View(s.c.mod, ptr, int64(n))
}
// ColumnJSON parses the JSON-encoded value of the result column
@@ -610,7 +606,7 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error {
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
data = util.AppendNumber(nil, s.ColumnFloat(col))
default:
panic(util.AssertErr())
}
@@ -622,12 +618,12 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error {
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnValue(col int) Value {
r := s.c.call("sqlite3_column_value",
uint64(s.handle), uint64(col))
ptr := ptr_t(s.c.call("sqlite3_column_value",
stk_t(s.handle), stk_t(col)))
return Value{
c: s.c,
unprot: true,
handle: uint32(r),
handle: ptr,
}
}
@@ -641,13 +637,13 @@ func (s *Stmt) ColumnValue(col int) Value {
// subsequent calls to [Stmt] methods.
func (s *Stmt) Columns(dest ...any) error {
defer s.c.arena.mark()()
count := uint64(len(dest))
count := int64(len(dest))
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
r := s.c.call("sqlite3_columns_go",
uint64(s.handle), count, uint64(typePtr), uint64(dataPtr))
if err := s.c.error(r); err != nil {
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if err := s.c.error(rc); err != nil {
return err
}
@@ -661,19 +657,19 @@ func (s *Stmt) Columns(dest ...any) error {
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr))
dest[i] = util.Read64[int64](s.c.mod, dataPtr)
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
case byte(NULL):
dest[i] = nil
default:
ptr := util.ReadUint32(s.c.mod, dataPtr+0)
ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0)
if ptr == 0 {
dest[i] = []byte{}
continue
}
len := util.ReadUint32(s.c.mod, dataPtr+4)
buf := util.View(s.c.mod, ptr, uint64(len))
len := util.Read32[int32](s.c.mod, dataPtr+4)
buf := util.View(s.c.mod, ptr, int64(len))
if types[i] == byte(TEXT) {
dest[i] = string(buf)
} else {

View File

@@ -365,7 +365,7 @@ func TestBlob_Reopen(t *testing.T) {
}
var rowids []int64
for i := 0; i < 100; i++ {
for range 100 {
err = db.Exec(`INSERT INTO test VALUES (zeroblob(10))`)
if err != nil {
t.Fatal(err)

View File

@@ -92,7 +92,7 @@ func testManyQueryRow(t params) {
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
var name string
for i := 0; i < 10000; i++ {
for i := range 10000 {
err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
@@ -164,11 +164,11 @@ func testPreparedStmt(t params) {
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
for range nRuns {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
for range 10 {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)

View File

@@ -1,8 +1,7 @@
package tests
import (
"context"
"errors"
"fmt"
"io"
"log"
"net/url"
@@ -10,7 +9,6 @@ import (
"os/exec"
"path/filepath"
"testing"
"time"
"golang.org/x/sync/errgroup"
@@ -24,15 +22,18 @@ import (
)
func TestMain(m *testing.M) {
sqlite3.AutoExtension(func(c *sqlite3.Conn) error {
return c.ConfigLog(func(code sqlite3.ExtendedErrorCode, msg string) {
// Having to do journal recovery is unexpected.
if errors.Is(code, sqlite3.NOTICE) {
log.Panicf("%v (%d): %s", code, code, msg)
} else {
log.Printf("%v (%d): %s", code, code, msg)
}
})
sqlite3.Initialize()
sqlite3.ConfigLog(func(code sqlite3.ExtendedErrorCode, msg string) {
switch code {
case sqlite3.NOTICE_RECOVER_WAL:
// Wal "recovery" is expected.
break
case sqlite3.NOTICE_RECOVER_ROLLBACK:
// Rollback journal recovery is an error.
log.Panicf("%v (%d): %s", code, code, msg)
default:
log.Printf("%v (%d): %s", code, code, msg)
}
})
m.Run()
}
@@ -54,6 +55,7 @@ func Test_parallel(t *testing.T) {
"?_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
createDB(t, name)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -67,7 +69,7 @@ func Test_wal(t *testing.T) {
if testing.Short() {
iter = 1000
} else {
iter = 2500
iter = 5000
}
name := "file:" +
@@ -75,6 +77,7 @@ func Test_wal(t *testing.T) {
"?_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(wal)" +
"&_pragma=synchronous(off)"
createDB(t, name)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -90,6 +93,7 @@ func Test_memdb(t *testing.T) {
name := memdb.TestDB(t, url.Values{
"_pragma": {"busy_timeout(10000)"},
})
createDB(t, name)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -113,6 +117,7 @@ func Test_adiantum(t *testing.T) {
"&_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
createDB(t, name)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -136,6 +141,7 @@ func Test_xts(t *testing.T) {
"&_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
createDB(t, name)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -155,14 +161,17 @@ func Test_MultiProcess_rollback(t *testing.T) {
"?_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
createDB(t, name)
exe, err := os.Executable()
if err != nil {
t.Fatal(err)
}
cmd := exec.Command(exe, append(os.Args[1:], "-test.v", "-test.run=Test_ChildProcess_rollback")...)
cmd := exec.Command(exe, append(os.Args[1:],
"-test.v", "-test.count=1", "-test.run=Test_ChildProcess_rollback")...)
out, err := cmd.StdoutPipe()
cmd.Stderr = os.Stderr
if err != nil {
t.Fatal(err)
}
@@ -214,14 +223,17 @@ func Test_MultiProcess_wal(t *testing.T) {
"?_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(wal)" +
"&_pragma=synchronous(off)"
createDB(t, name)
exe, err := os.Executable()
if err != nil {
t.Fatal(err)
}
cmd := exec.Command(exe, append(os.Args[1:], "-test.v", "-test.run=Test_ChildProcess_wal")...)
cmd := exec.Command(exe, append(os.Args[1:],
"-test.v", "-test.count=1", "-test.run=Test_ChildProcess_wal")...)
out, err := cmd.StdoutPipe()
cmd.Stderr = os.Stderr
if err != nil {
t.Fatal(err)
}
@@ -263,14 +275,14 @@ func Benchmark_parallel(b *testing.B) {
b.Skip("skipping without shared memory")
}
sqlite3.Initialize()
b.ResetTimer()
name := "file:" +
filepath.Join(b.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
createDB(b, name)
b.ResetTimer()
testParallel(b, name, b.N)
}
@@ -279,55 +291,51 @@ func Benchmark_wal(b *testing.B) {
b.Skip("skipping without shared memory")
}
sqlite3.Initialize()
b.ResetTimer()
name := "file:" +
filepath.Join(b.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(wal)" +
"&_pragma=synchronous(off)"
createDB(b, name)
b.ResetTimer()
testParallel(b, name, b.N)
}
func Benchmark_memdb(b *testing.B) {
sqlite3.Initialize()
b.ResetTimer()
name := memdb.TestDB(b, url.Values{
"_pragma": {"busy_timeout(10000)"},
})
createDB(b, name)
b.ResetTimer()
testParallel(b, name, b.N)
}
func createDB(t testing.TB, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
}
func testParallel(t testing.TB, name string, n int) {
writer := func() error {
db, err := sqlite3.Open(name)
if err != nil {
return err
return fmt.Errorf("writer: open: %w", err)
}
defer db.Close()
err = db.BusyHandler(func(ctx context.Context, count int) (retry bool) {
select {
case <-time.After(time.Millisecond):
return true
case <-ctx.Done():
return false
}
})
if err != nil {
return err
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
return err
return fmt.Errorf("writer: insert: %w", err)
}
return db.Close()
@@ -336,13 +344,13 @@ func testParallel(t testing.TB, name string, n int) {
reader := func() error {
db, err := sqlite3.Open(name)
if err != nil {
return err
return fmt.Errorf("reader: open: %w", err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
return err
return fmt.Errorf("reader: select: %w", err)
}
defer stmt.Close()
@@ -351,15 +359,15 @@ func testParallel(t testing.TB, name string, n int) {
row++
}
if err := stmt.Err(); err != nil {
return err
return fmt.Errorf("reader: step: %w", err)
}
if row%3 != 0 {
t.Errorf("got %d rows, want multiple of 3", row)
return fmt.Errorf("reader: got %d rows, want multiple of 3", row)
}
err = stmt.Close()
if err != nil {
return err
return fmt.Errorf("reader: close: %w", err)
}
return db.Close()
@@ -372,7 +380,7 @@ func testParallel(t testing.TB, name string, n int) {
var group errgroup.Group
group.SetLimit(6)
for i := 0; i < n; i++ {
for i := range n {
if i&7 != 7 {
group.Go(reader)
} else {

View File

@@ -4,7 +4,6 @@ import (
"database/sql"
"encoding/json"
"math"
"reflect"
"testing"
"time"
@@ -52,8 +51,7 @@ func TestQuote(t *testing.T) {
}
}()
got := sqlite3.Quote(tt.val)
if !reflect.DeepEqual(got, tt.want) {
if got := sqlite3.Quote(tt.val); got != tt.want {
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
}
})
@@ -81,8 +79,7 @@ func TestQuoteIdentifier(t *testing.T) {
}
}()
got := sqlite3.QuoteIdentifier(tt.id)
if !reflect.DeepEqual(got, tt.want) {
if got := sqlite3.QuoteIdentifier(tt.id); got != tt.want {
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
}
})

29
txn.go
View File

@@ -229,13 +229,12 @@ func (c *Conn) txnExecInterrupted(sql string) error {
//
// https://sqlite.org/c3ref/txn_state.html
func (c *Conn) TxnState(schema string) TxnState {
var ptr uint32
var ptr ptr_t
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
r := c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr))
return TxnState(r)
return TxnState(c.call("sqlite3_txn_state", stk_t(c.handle), stk_t(ptr)))
}
// CommitHook registers a callback function to be invoked
@@ -244,11 +243,11 @@ func (c *Conn) TxnState(schema string) TxnState {
//
// https://sqlite.org/c3ref/commit_hook.html
func (c *Conn) CommitHook(cb func() (ok bool)) {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_commit_hook_go", uint64(c.handle), enable)
c.call("sqlite3_commit_hook_go", stk_t(c.handle), stk_t(enable))
c.commit = cb
}
@@ -257,11 +256,11 @@ func (c *Conn) CommitHook(cb func() (ok bool)) {
//
// https://sqlite.org/c3ref/commit_hook.html
func (c *Conn) RollbackHook(cb func()) {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable)
c.call("sqlite3_rollback_hook_go", stk_t(c.handle), stk_t(enable))
c.rollback = cb
}
@@ -270,15 +269,15 @@ func (c *Conn) RollbackHook(cb func()) {
//
// https://sqlite.org/c3ref/update_hook.html
func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
var enable uint64
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_update_hook_go", uint64(c.handle), enable)
c.call("sqlite3_update_hook_go", stk_t(c.handle), stk_t(enable))
c.update = cb
}
func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback uint32) {
func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
if !c.commit() {
rollback = 1
@@ -287,17 +286,17 @@ func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback u
return rollback
}
func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {
func rollbackCallback(ctx context.Context, mod api.Module, pDB ptr_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil {
c.rollback()
}
}
func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) {
func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid int64) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil {
schema := util.ReadString(mod, zSchema, _MAX_NAME)
table := util.ReadString(mod, zTabName, _MAX_NAME)
c.update(action, schema, table, int64(rowid))
c.update(action, schema, table, rowid)
}
}
@@ -305,6 +304,6 @@ func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action Auth
//
// https://sqlite.org/c3ref/db_cacheflush.html
func (c *Conn) CacheFlush() error {
r := c.call("sqlite3_db_cacheflush", uint64(c.handle))
return c.error(r)
rc := res_t(c.call("sqlite3_db_cacheflush", stk_t(c.handle)))
return c.error(rc)
}

View File

@@ -50,7 +50,7 @@ func ParseTable(sql string) (_ *Table, err error) {
copy(buf, sql)
}
stack := [...]uint64{sqlp, uint64(len(sql)), errp}
stack := [...]util.Stk_t{sqlp, util.Stk_t(len(sql)), errp}
err = mod.ExportedFunction("sql3parse_table").CallWithStack(ctx, stack[:])
if err != nil {
return nil, err
@@ -96,9 +96,9 @@ func (t *Table) load(mod api.Module, ptr uint32, sql string) {
t.IsWithoutRowID = loadBool(mod, ptr+26)
t.IsStrict = loadBool(mod, ptr+27)
t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, res *Column) {
t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, ret *Column) {
p, _ := mod.Memory().ReadUint32Le(ptr)
res.load(mod, p, sql)
ret.load(mod, p, sql)
})
t.Type = loadEnum[StatementType](mod, ptr+44)
@@ -166,8 +166,8 @@ type ForeignKey struct {
func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) {
f.Table = loadString(mod, ptr+0, sql)
f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, res *string) {
*res = loadString(mod, ptr, sql)
f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, ret *string) {
*ret = loadString(mod, ptr, sql)
})
f.OnDelete = loadEnum[FKAction](mod, ptr+16)
@@ -191,12 +191,12 @@ func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T)) []T {
return nil
}
len, _ := mod.Memory().ReadUint32Le(ptr + 0)
res := make([]T, len)
for i := range res {
fn(ref, &res[i])
ret := make([]T, len)
for i := range ret {
fn(ref, &ret[i])
ref += 4
}
return res
return ret
}
func loadEnum[T ~uint32](mod api.Module, ptr uint32) T {

View File

@@ -38,18 +38,18 @@ func WrapLockState(f vfs.File) vfs.LockLevel {
return vfs.LOCK_EXCLUSIVE + 1 // UNKNOWN_LOCK
}
// WrapPersistentWAL helps wrap [vfs.FilePersistentWAL].
func WrapPersistentWAL(f vfs.File) bool {
if f, ok := f.(vfs.FilePersistentWAL); ok {
return f.PersistentWAL()
// WrapPersistWAL helps wrap [vfs.FilePersistWAL].
func WrapPersistWAL(f vfs.File) bool {
if f, ok := f.(vfs.FilePersistWAL); ok {
return f.PersistWAL()
}
return false
}
// WrapSetPersistentWAL helps wrap [vfs.FilePersistentWAL].
func WrapSetPersistentWAL(f vfs.File, keepWAL bool) {
if f, ok := f.(vfs.FilePersistentWAL); ok {
f.SetPersistentWAL(keepWAL)
// WrapSetPersistWAL helps wrap [vfs.FilePersistWAL].
func WrapSetPersistWAL(f vfs.File, keepWAL bool) {
if f, ok := f.(vfs.FilePersistWAL); ok {
f.SetPersistWAL(keepWAL)
}
}
@@ -99,6 +99,14 @@ func WrapOverwrite(f vfs.File) error {
return sqlite3.NOTFOUND
}
// WrapSyncSuper helps wrap [vfs.FileSync].
func WrapSyncSuper(f vfs.File, super string) error {
if f, ok := f.(vfs.FileSync); ok {
return f.SyncSuper(super)
}
return sqlite3.NOTFOUND
}
// WrapCommitPhaseTwo helps wrap [vfs.FileCommitPhaseTwo].
func WrapCommitPhaseTwo(f vfs.File) error {
if f, ok := f.(vfs.FileCommitPhaseTwo); ok {
@@ -153,6 +161,13 @@ func WrapPragma(f vfs.File, name, value string) (string, error) {
return "", sqlite3.NOTFOUND
}
// WrapBusyHandler helps wrap [vfs.FilePragma].
func WrapBusyHandler(f vfs.File, handler func() bool) {
if f, ok := f.(vfs.FileBusyHandler); ok {
f.BusyHandler(handler)
}
}
// WrapSharedMemory helps wrap [vfs.FileSharedMemory].
func WrapSharedMemory(f vfs.File) vfs.SharedMemory {
if f, ok := f.(vfs.FileSharedMemory); ok {

View File

@@ -14,27 +14,27 @@ import (
// https://sqlite.org/c3ref/value.html
type Value struct {
c *Conn
handle uint32
handle ptr_t
unprot bool
copied bool
}
func (v Value) protected() uint64 {
func (v Value) protected() stk_t {
if v.unprot {
panic(util.ValueErr)
}
return uint64(v.handle)
return stk_t(v.handle)
}
// Dup makes a copy of the SQL value and returns a pointer to that copy.
//
// https://sqlite.org/c3ref/value_dup.html
func (v Value) Dup() *Value {
r := v.c.call("sqlite3_value_dup", uint64(v.handle))
ptr := ptr_t(v.c.call("sqlite3_value_dup", stk_t(v.handle)))
return &Value{
c: v.c,
copied: true,
handle: uint32(r),
handle: ptr,
}
}
@@ -45,7 +45,7 @@ func (dup *Value) Close() error {
if !dup.copied {
panic(util.ValueErr)
}
dup.c.call("sqlite3_value_free", uint64(dup.handle))
dup.c.call("sqlite3_value_free", stk_t(dup.handle))
dup.handle = 0
return nil
}
@@ -54,16 +54,14 @@ func (dup *Value) Close() error {
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Type() Datatype {
r := v.c.call("sqlite3_value_type", v.protected())
return Datatype(r)
return Datatype(v.c.call("sqlite3_value_type", v.protected()))
}
// Type returns the numeric datatype of the value.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) NumericType() Datatype {
r := v.c.call("sqlite3_value_numeric_type", v.protected())
return Datatype(r)
return Datatype(v.c.call("sqlite3_value_numeric_type", v.protected()))
}
// Bool returns the value as a bool.
@@ -87,16 +85,15 @@ func (v Value) Int() int {
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Int64() int64 {
r := v.c.call("sqlite3_value_int64", v.protected())
return int64(r)
return int64(v.c.call("sqlite3_value_int64", v.protected()))
}
// Float returns the value as a float64.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Float() float64 {
r := v.c.call("sqlite3_value_double", v.protected())
return math.Float64frombits(r)
f := uint64(v.c.call("sqlite3_value_double", v.protected()))
return math.Float64frombits(f)
}
// Time returns the value as a [time.Time].
@@ -141,8 +138,8 @@ func (v Value) Blob(buf []byte) []byte {
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
r := v.c.call("sqlite3_value_text", v.protected())
return v.rawBytes(uint32(r))
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
return v.rawBytes(ptr)
}
// RawBlob returns the value as a []byte.
@@ -151,24 +148,24 @@ func (v Value) RawText() []byte {
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
r := v.c.call("sqlite3_value_blob", v.protected())
return v.rawBytes(uint32(r))
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
return v.rawBytes(ptr)
}
func (v Value) rawBytes(ptr uint32) []byte {
func (v Value) rawBytes(ptr ptr_t) []byte {
if ptr == 0 {
return nil
}
r := v.c.call("sqlite3_value_bytes", v.protected())
return util.View(v.c.mod, ptr, r)
n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
return util.View(v.c.mod, ptr, int64(n))
}
// Pointer gets the pointer associated with this value,
// or nil if it has no associated pointer.
func (v Value) Pointer() any {
r := v.c.call("sqlite3_value_pointer_go", v.protected())
return util.GetHandle(v.c.ctx, uint32(r))
ptr := ptr_t(v.c.call("sqlite3_value_pointer_go", v.protected()))
return util.GetHandle(v.c.ctx, ptr)
}
// JSON parses a JSON-encoded value
@@ -185,7 +182,7 @@ func (v Value) JSON(ptr any) error {
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
data = util.AppendNumber(nil, v.Float())
default:
panic(util.AssertErr())
}
@@ -197,16 +194,16 @@ func (v Value) JSON(ptr any) error {
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) NoChange() bool {
r := v.c.call("sqlite3_value_nochange", v.protected())
return r != 0
b := int32(v.c.call("sqlite3_value_nochange", v.protected()))
return b != 0
}
// FromBind returns true if value originated from a bound parameter.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) FromBind() bool {
r := v.c.call("sqlite3_value_frombind", v.protected())
return r != 0
b := int32(v.c.call("sqlite3_value_frombind", v.protected()))
return b != 0
}
// InFirst returns the first element
@@ -216,13 +213,13 @@ func (v Value) FromBind() bool {
func (v Value) InFirst() (Value, error) {
defer v.c.arena.mark()()
valPtr := v.c.arena.new(ptrlen)
r := v.c.call("sqlite3_vtab_in_first", uint64(v.handle), uint64(valPtr))
if err := v.c.error(r); err != nil {
rc := res_t(v.c.call("sqlite3_vtab_in_first", stk_t(v.handle), stk_t(valPtr)))
if err := v.c.error(rc); err != nil {
return Value{}, err
}
return Value{
c: v.c,
handle: util.ReadUint32(v.c.mod, valPtr),
handle: util.Read32[ptr_t](v.c.mod, valPtr),
}, nil
}
@@ -233,12 +230,12 @@ func (v Value) InFirst() (Value, error) {
func (v Value) InNext() (Value, error) {
defer v.c.arena.mark()()
valPtr := v.c.arena.new(ptrlen)
r := v.c.call("sqlite3_vtab_in_next", uint64(v.handle), uint64(valPtr))
if err := v.c.error(r); err != nil {
rc := res_t(v.c.call("sqlite3_vtab_in_next", stk_t(v.handle), stk_t(valPtr)))
if err := v.c.error(rc); err != nil {
return Value{}, err
}
return Value{
c: v.c,
handle: util.ReadUint32(v.c.mod, valPtr),
handle: util.Read32[ptr_t](v.c.mod, valPtr),
}, nil
}

View File

@@ -6,22 +6,30 @@ It replaces the default SQLite VFS with a **pure Go** implementation,
and exposes [interfaces](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs#VFS)
that should allow you to implement your own [custom VFSes](#custom-vfses).
Since it is a from scratch reimplementation,
there are naturally some ways it deviates from the original.
See the [support matrix](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix)
for the list of supported OS and CPU architectures.
The main differences are [file locking](#file-locking) and [WAL mode](#write-ahead-logging) support.
Since this is a from scratch reimplementation,
there are naturally some ways it deviates from the original.
It's also not as battle tested as the original.
The main differences to be aware of are
[file locking](#file-locking) and
[WAL mode](#write-ahead-logging) support.
### File Locking
POSIX advisory locks, which SQLite uses on Unix, are
[broken by design](https://github.com/sqlite/sqlite/blob/b74eb0/src/os_unix.c#L1073-L1161).
POSIX advisory locks,
which SQLite uses on [Unix](https://github.com/sqlite/sqlite/blob/5d60f4/src/os_unix.c#L13-L14),
are [broken by design](https://github.com/sqlite/sqlite/blob/5d60f4/src/os_unix.c#L1074-L1162).
Instead, on Linux and macOS, this package uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
This package can also use
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2),
albeit with reduced concurrency (`BEGIN IMMEDIATE` behaves like `BEGIN EXCLUSIVE`).
albeit with reduced concurrency (`BEGIN IMMEDIATE` behaves like `BEGIN EXCLUSIVE`,
[docs](https://sqlite.org/lang_transaction.html#immediate)).
BSD locks are the default on BSD and illumos,
but you can opt into them with the `sqlite3_flock` build tag.
@@ -44,11 +52,11 @@ to check if your build supports file locking.
### Write-Ahead Logging
On Unix, this package may use `mmap` to implement
On Unix, this package uses `mmap` to implement
[shared-memory for the WAL-index](https://sqlite.org/wal.html#implementation_of_shared_memory_for_the_wal_index),
like SQLite.
On Windows, this package may use `MapViewOfFile`, like SQLite.
On Windows, this package uses `MapViewOfFile`, like SQLite.
You can also opt into a cross-platform, in-process, memory sharing implementation
with the `sqlite3_dotlk` build tag.
@@ -63,6 +71,11 @@ you must disable connection pooling by calling
You can use [`vfs.SupportsSharedMemory`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs#SupportsSharedMemory)
to check if your build supports shared memory.
### Blocking Locks
On Windows and macOS, this package implements
[Wal-mode blocking locks](https://sqlite.org/src/doc/tip/doc/wal-lock.md).
### Batch-Atomic Write
On Linux, this package may support
@@ -94,8 +107,10 @@ The VFS can be customized with a few build tags:
> [`unix-flock` VFS](https://sqlite.org/compile.html#enable_locking_style);
> `sqlite3_dotlk` builds are compatible with the
> [`unix-dotfile` VFS](https://sqlite.org/compile.html#enable_locking_style).
> If incompatible file locking is used, accessing databases concurrently with
> _other_ SQLite libraries will eventually corrupt data.
> [!CAUTION]
> Concurrently accessing databases using incompatible VFSes
> will eventually corrupt data.
### Custom VFSes

Some files were not shown because too many files have changed in this diff Show More