mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 14:09:13 +00:00
Compare commits
271 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4764fb2fa | ||
|
|
42df71f3ff | ||
|
|
91e969c06b | ||
|
|
8ab0ddf53e | ||
|
|
74d22ded0a | ||
|
|
d962611796 | ||
|
|
7df3814c34 | ||
|
|
c5f49b835a | ||
|
|
0e55451a0b | ||
|
|
ea9a58ab19 | ||
|
|
0b46e74ea6 | ||
|
|
8dca850bee | ||
|
|
5b78823416 | ||
|
|
1764a571da | ||
|
|
9837310af7 | ||
|
|
ec8961a621 | ||
|
|
8c37aa2d97 | ||
|
|
ca93c498e7 | ||
|
|
15e9087fa8 | ||
|
|
7028e3a5b9 | ||
|
|
03bb20de6e | ||
|
|
20a51a344e | ||
|
|
2dbcc480f7 | ||
|
|
0f0716c438 | ||
|
|
0286e50e25 | ||
|
|
8ac10eb8b4 | ||
|
|
0ff41bb966 | ||
|
|
ba9caf0405 | ||
|
|
2c167dd116 | ||
|
|
ce0da893b4 | ||
|
|
9bbbab77f6 | ||
|
|
bab2d26652 | ||
|
|
3132b272de | ||
|
|
8f9a6ca4c1 | ||
|
|
99b097de3b | ||
|
|
4a956e80a2 | ||
|
|
5f4ff03f6f | ||
|
|
5890049488 | ||
|
|
5e73c5d714 | ||
|
|
6d92aa16ef | ||
|
|
191d1337e7 | ||
|
|
b65e894849 | ||
|
|
0b040d3f09 | ||
|
|
1db4366226 | ||
|
|
9e1cbfb5bb | ||
|
|
7f2d70a0f3 | ||
|
|
ea860e407d | ||
|
|
d4561d08f9 | ||
|
|
14c1e490b4 | ||
|
|
23aad5f62f | ||
|
|
e5bd10a1ff | ||
|
|
5cf06c45f7 | ||
|
|
08f9fc758a | ||
|
|
b588d5f991 | ||
|
|
4c24bd0cb6 | ||
|
|
cc353e4848 | ||
|
|
c3ebb04045 | ||
|
|
11e064574c | ||
|
|
770420289a | ||
|
|
62f69011f1 | ||
|
|
4f9e3f900b | ||
|
|
4e90618350 | ||
|
|
54bb94ce58 | ||
|
|
07fec784e1 | ||
|
|
da4638cbff | ||
|
|
085872c2f3 | ||
|
|
de49aa2b06 | ||
|
|
1f3ad0165e | ||
|
|
0bda48d1d9 | ||
|
|
0026bc91aa | ||
|
|
d84ca9d627 | ||
|
|
5d14e01f94 | ||
|
|
342df983d4 | ||
|
|
00476fb1e2 | ||
|
|
8a64ee6eaa | ||
|
|
8f9a8e2752 | ||
|
|
d8880e4cee | ||
|
|
4b154a842c | ||
|
|
758a53e9bf | ||
|
|
1a42b4c590 | ||
|
|
7e4ec1df1c | ||
|
|
2c582a1d66 | ||
|
|
20a67ca669 | ||
|
|
789e2dc136 | ||
|
|
0399f10c06 | ||
|
|
75c6744b5b | ||
|
|
754e806164 | ||
|
|
2640c9fb54 | ||
|
|
9719d4b0e3 | ||
|
|
b21c69dc1f | ||
|
|
b0f8ff44a5 | ||
|
|
f37bca6a80 | ||
|
|
b4e8fcb752 | ||
|
|
14b98a5d05 | ||
|
|
36a62264f9 | ||
|
|
33ea564f38 | ||
|
|
5c55d8692f | ||
|
|
be2f3036b4 | ||
|
|
784f82f42f | ||
|
|
cd6ba43e77 | ||
|
|
d7aef63844 | ||
|
|
64e5046f10 | ||
|
|
0bdce8aa68 | ||
|
|
69a2881a10 | ||
|
|
24ad4445f1 | ||
|
|
c159bbd88f | ||
|
|
c90f8205f7 | ||
|
|
b64b9b0415 | ||
|
|
9142e19d61 | ||
|
|
4a76f2b064 | ||
|
|
c9b364507e | ||
|
|
2204b96ff6 | ||
|
|
b46f480d79 | ||
|
|
040a026925 | ||
|
|
e678040a4e | ||
|
|
f1cc12569c | ||
|
|
721a987e0e | ||
|
|
f3d65142cc | ||
|
|
93f711c77b | ||
|
|
341bd063e8 | ||
|
|
f765882670 | ||
|
|
ff3676ff4a | ||
|
|
54877a53cd | ||
|
|
fccc6c10a7 | ||
|
|
fc21ffcc71 | ||
|
|
687e643d7a | ||
|
|
fc5ced209c | ||
|
|
c1bed07e3a | ||
|
|
a0771f2363 | ||
|
|
6bad547d3d | ||
|
|
c2c1aea578 | ||
|
|
60ab485b29 | ||
|
|
e17a432fde | ||
|
|
c780ef16e2 | ||
|
|
b609930142 | ||
|
|
fd165ce724 | ||
|
|
d3973b23e3 | ||
|
|
320b68e74f | ||
|
|
2c3850e5d1 | ||
|
|
db7aacff9f | ||
|
|
d748d98e39 | ||
|
|
13b8642384 | ||
|
|
29c5c816cb | ||
|
|
b32db76da6 | ||
|
|
383f620a1e | ||
|
|
a3c3515e96 | ||
|
|
e580f080b9 | ||
|
|
9ea7099c24 | ||
|
|
29aa365806 | ||
|
|
bb87a920f7 | ||
|
|
48379336dc | ||
|
|
251a92fa1a | ||
|
|
f5206ea8da | ||
|
|
68ef4593d6 | ||
|
|
79bf171210 | ||
|
|
ad16d329ea | ||
|
|
9706fa9607 | ||
|
|
45494f5fb6 | ||
|
|
1b0bf3495e | ||
|
|
73ac7e06f6 | ||
|
|
a3ce8f9de5 | ||
|
|
2043d5fca4 | ||
|
|
3bd11a0a86 | ||
|
|
39f3fa64eb | ||
|
|
4c19387535 | ||
|
|
e6c9f18934 | ||
|
|
970eb6a2f9 | ||
|
|
fac27b8bab | ||
|
|
9f626b2f52 | ||
|
|
1f5d8bf7df | ||
|
|
41dc46af7e | ||
|
|
e5c285b783 | ||
|
|
6290a14990 | ||
|
|
948641194b | ||
|
|
befed7cf23 | ||
|
|
3547d9ffb0 | ||
|
|
a67165eb09 | ||
|
|
0ba393199a | ||
|
|
9e4258bc46 | ||
|
|
b645721d10 | ||
|
|
6c296231a5 | ||
|
|
c067e3630b | ||
|
|
35a2dbd847 | ||
|
|
b36f73c66d | ||
|
|
d36f19fd91 | ||
|
|
eba71b1f42 | ||
|
|
d78239bfbf | ||
|
|
49852732b2 | ||
|
|
9b90d076cb | ||
|
|
15b94577b1 | ||
|
|
25557244cc | ||
|
|
c2d3bf0cfc | ||
|
|
58a5682084 | ||
|
|
1ed954e96f | ||
|
|
9e7a0a875d | ||
|
|
26adda4529 | ||
|
|
2f6cd8de1d | ||
|
|
e027e055ff | ||
|
|
63fdc141e5 | ||
|
|
0bbd145a49 | ||
|
|
c755ef96e6 | ||
|
|
9a69e407cc | ||
|
|
e9db0d8e84 | ||
|
|
dadf53e175 | ||
|
|
f536765206 | ||
|
|
12034c4f0b | ||
|
|
b4e5d1a213 | ||
|
|
b06c7dda6c | ||
|
|
5e1909a20e | ||
|
|
77d74baca5 | ||
|
|
4142680d5a | ||
|
|
9f4fe6f27c | ||
|
|
7870ce0690 | ||
|
|
ec3226e16e | ||
|
|
4dd7bd0ff2 | ||
|
|
975feb2fd4 | ||
|
|
58f8c2d33e | ||
|
|
019660eed6 | ||
|
|
30c1bcdbe9 | ||
|
|
9b4002f5ac | ||
|
|
2a78d4bc2b | ||
|
|
c09623a903 | ||
|
|
fa613f9ddb | ||
|
|
57997201ee | ||
|
|
6995cca5c0 | ||
|
|
a10eef3ac8 | ||
|
|
d627ca3dc1 | ||
|
|
b2f7ab8335 | ||
|
|
c9135b9823 | ||
|
|
0d9ed94aad | ||
|
|
1d951ecd18 | ||
|
|
c0298ad274 | ||
|
|
42bad5891a | ||
|
|
40090d8250 | ||
|
|
d2f162972d | ||
|
|
e2da469834 | ||
|
|
1677b97fa4 | ||
|
|
407e13d238 | ||
|
|
9132f74b69 | ||
|
|
c024121fd2 | ||
|
|
aa8287f8e7 | ||
|
|
ab09da7136 | ||
|
|
a159b548ed | ||
|
|
d9b37307e7 | ||
|
|
3bae1d7d4b | ||
|
|
8887036c20 | ||
|
|
ccb3dcd097 | ||
|
|
a9f33cc2b0 | ||
|
|
f025ffb385 | ||
|
|
aa4357a78f | ||
|
|
aef7f051a8 | ||
|
|
a79ee4c2c6 | ||
|
|
7424747338 | ||
|
|
11830e05a6 | ||
|
|
7dc4520690 | ||
|
|
0c09dd89c2 | ||
|
|
31c5000875 | ||
|
|
8175407754 | ||
|
|
abfad02d95 | ||
|
|
f7c3fb8062 | ||
|
|
c3633dda35 | ||
|
|
f2d894194d | ||
|
|
e08c7b3adf | ||
|
|
66601dd3cb | ||
|
|
58b66b75f1 | ||
|
|
e0c6086aa9 | ||
|
|
9bc39c5b91 | ||
|
|
12193cedea | ||
|
|
71d95bf9d5 | ||
|
|
7e23100ff7 | ||
|
|
e32d8401fb |
23
.github/actions/lfs/action.yml
vendored
23
.github/actions/lfs/action.yml
vendored
@@ -1,23 +0,0 @@
|
||||
name: Git LFS pull
|
||||
description: Cached Git LFS pull.
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Create LFS file list
|
||||
shell: bash
|
||||
run: git lfs ls-files --long | cut -d ' ' -f1 | sort > .lfs-assets-id
|
||||
|
||||
- name: Restore LFS cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .git/lfs/objects
|
||||
key: lfs-${{ hashFiles('.lfs-assets-id') }}
|
||||
restore-keys: lfs-
|
||||
enableCrossOsArchive: true
|
||||
|
||||
- name: Git LFS pull
|
||||
shell: bash
|
||||
run: |
|
||||
git lfs pull
|
||||
git lfs prune
|
||||
7
.github/actions/vmactions/template.yml
vendored
7
.github/actions/vmactions/template.yml
vendored
@@ -1,11 +1,6 @@
|
||||
name: VM Actions matrix
|
||||
description: VM Actions matrix template
|
||||
|
||||
inputs:
|
||||
run:
|
||||
description: The CI command to run
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
@@ -13,4 +8,4 @@ runs:
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
run: ${{inputs.run}}
|
||||
run: . ./test.sh
|
||||
2
.github/workflows/build-test.sh
vendored
2
.github/workflows/build-test.sh
vendored
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
echo 'set -eu' > test.sh
|
||||
echo 'set -eux' > test.sh
|
||||
|
||||
for p in $(go list ./...); do
|
||||
dir=".${p#github.com/ncruces/go-sqlite3}"
|
||||
|
||||
25
.github/workflows/cross.sh
vendored
25
.github/workflows/cross.sh
vendored
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
echo android ; GOOS=android GOARCH=amd64 go build .
|
||||
echo darwin ; GOOS=darwin GOARCH=amd64 go build .
|
||||
echo dragonfly ; GOOS=dragonfly GOARCH=amd64 go build .
|
||||
echo freebsd ; GOOS=freebsd GOARCH=amd64 go build .
|
||||
echo illumos ; GOOS=illumos GOARCH=amd64 go build .
|
||||
echo ios ; GOOS=ios GOARCH=amd64 go build .
|
||||
echo linux ; GOOS=linux GOARCH=amd64 go build .
|
||||
echo netbsd ; GOOS=netbsd GOARCH=amd64 go build .
|
||||
echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
|
||||
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
|
||||
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
|
||||
echo windows ; GOOS=windows GOARCH=amd64 go build .
|
||||
echo aix ; GOOS=aix GOARCH=ppc64 go build .
|
||||
echo js ; GOOS=js GOARCH=wasm go build .
|
||||
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
|
||||
echo linux-flock ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_flock .
|
||||
echo linux-dotlk ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_dotlk .
|
||||
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
|
||||
echo darwin-dotlk ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_dotlk .
|
||||
echo windows-dotlk ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_dotlk .
|
||||
echo freebsd-dotlk ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_dotlk .
|
||||
echo solaris-dotlk ; GOOS=solaris GOARCH=amd64 go build -tags sqlite3_dotlk .
|
||||
16
.github/workflows/cross.yml
vendored
16
.github/workflows/cross.yml
vendored
@@ -1,16 +0,0 @@
|
||||
name: Cross compile
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Build
|
||||
run: .github/workflows/cross.sh
|
||||
23
.github/workflows/libc.yml
vendored
Normal file
23
.github/workflows/libc.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Benchmark libc
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-24.04, ubuntu-24.04-arm, macos-15, macos-15-intel]
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Benchmark
|
||||
shell: bash
|
||||
run: sqlite3/libc/benchmark.sh
|
||||
25
.github/workflows/repro.sh
vendored
25
.github/workflows/repro.sh
vendored
@@ -1,34 +1,15 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if [[ "$OSTYPE" == "linux"* ]]; then
|
||||
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-linux.tar.gz"
|
||||
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_120_b/binaryen-version_120_b-x86_64-linux.tar.gz"
|
||||
elif [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-arm64-macos.tar.gz"
|
||||
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_120_b/binaryen-version_120_b-arm64-macos.tar.gz"
|
||||
elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
|
||||
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-windows.tar.gz"
|
||||
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_120_b/binaryen-version_120_b-x86_64-windows.tar.gz"
|
||||
fi
|
||||
|
||||
# Download tools
|
||||
mkdir -p tools/
|
||||
[ -d "tools/wasi-sdk" ] || curl -#L "$WASI_SDK" | tar xzC tools &
|
||||
[ -d "tools/binaryen" ] || curl -#L "$BINARYEN" | tar xzC tools &
|
||||
wait
|
||||
|
||||
[ -d "tools/wasi-sdk" ] || mv "tools/wasi-sdk"* "tools/wasi-sdk"
|
||||
[ -d "tools/binaryen" ] || mv "tools/binaryen"* "tools/binaryen"
|
||||
|
||||
# Download and build SQLite
|
||||
sqlite3/download.sh
|
||||
sqlite3/tools.sh
|
||||
embed/build.sh
|
||||
embed/bcw2/build.sh
|
||||
|
||||
# Download and build sqlite-createtable-parser
|
||||
util/sql3util/parse/download.sh
|
||||
util/sql3util/parse/build.sh
|
||||
util/sql3util/wasm/download.sh
|
||||
util/sql3util/wasm/build.sh
|
||||
|
||||
# Check diffs
|
||||
git diff --exit-code
|
||||
8
.github/workflows/repro.yml
vendored
8
.github/workflows/repro.yml
vendored
@@ -17,18 +17,16 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: ilammy/msvc-dev-cmd@v1
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with: { go-version: stable }
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Build
|
||||
shell: bash
|
||||
run: .github/workflows/repro.sh
|
||||
|
||||
- uses: actions/attest-build-provenance@v2
|
||||
- uses: actions/attest-build-provenance@v3
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
with:
|
||||
subject-path: |
|
||||
embed/sqlite3.wasm
|
||||
embed/bcw2/bcw2.wasm
|
||||
util/sql3util/parse/sql3parse_table.wasm
|
||||
util/sql3util/wasm/sql3parse_table.wasm
|
||||
172
.github/workflows/test.yml
vendored
172
.github/workflows/test.yml
vendored
@@ -2,36 +2,38 @@ name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
branches: [ 'main' ]
|
||||
paths:
|
||||
- '**.go'
|
||||
- '**.mod'
|
||||
- '**.wasm'
|
||||
- '**.wasm.bz2'
|
||||
- '**.yml'
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
branches: [ 'main' ]
|
||||
paths:
|
||||
- '**.go'
|
||||
- '**.mod'
|
||||
- '**.wasm'
|
||||
- '**.wasm.bz2'
|
||||
- '**.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [macos-latest, ubuntu-latest, windows-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Git LFS pull
|
||||
uses: ./.github/actions/lfs
|
||||
|
||||
- name: Format
|
||||
run: gofmt -s -w . && git diff --exit-code
|
||||
if: matrix.os != 'windows-latest'
|
||||
@@ -49,25 +51,35 @@ jobs:
|
||||
run: go vet ./...
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
run: go build ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -v ./... -bench . -benchtime=1x
|
||||
run: go test ./... -bench . -benchtime=1x
|
||||
|
||||
- name: Test BSD locks
|
||||
run: go test -v -tags sqlite3_flock ./...
|
||||
if: matrix.os == 'macos-latest'
|
||||
run: go test -tags sqlite3_flock ./...
|
||||
if: matrix.os != 'windows-latest'
|
||||
|
||||
- name: Test dot locks
|
||||
run: go test -v -tags sqlite3_dotlk ./...
|
||||
run: go test -tags sqlite3_dotlk ./...
|
||||
if: matrix.os != 'windows-latest'
|
||||
|
||||
- name: Test modules
|
||||
shell: bash
|
||||
run: |
|
||||
go work init .
|
||||
go work use -r embed/bcw2 gormlite
|
||||
go test ./embed/bcw2 ./gormlite
|
||||
|
||||
- name: Test GORM
|
||||
shell: bash
|
||||
run: gormlite/test.sh
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
- name: Collect coverage
|
||||
run: go run github.com/dave/courtney@latest
|
||||
run: |
|
||||
go get -tool github.com/dave/courtney@v0.4.4
|
||||
go tool courtney
|
||||
if: |
|
||||
github.event_name == 'push' &&
|
||||
matrix.os == 'ubuntu-latest'
|
||||
@@ -81,38 +93,46 @@ jobs:
|
||||
github.event_name == 'push' &&
|
||||
matrix.os == 'ubuntu-latest'
|
||||
|
||||
test-bsd:
|
||||
test-cross:
|
||||
strategy:
|
||||
matrix:
|
||||
os:
|
||||
- name: freebsd
|
||||
version: '14.2'
|
||||
flags: '-test.v'
|
||||
version: '15.0'
|
||||
- name: netbsd
|
||||
version: '10.0'
|
||||
flags: '-test.v'
|
||||
version: '10.1'
|
||||
- name: illumos
|
||||
action: omnios
|
||||
version: 'r151056'
|
||||
- name: openbsd
|
||||
version: '7.6'
|
||||
flags: '-test.v -test.short'
|
||||
version: '7.8'
|
||||
tflags: '-test.short'
|
||||
- name: freebsd
|
||||
arch: arm64
|
||||
version: '15.0'
|
||||
tflags: '-test.short'
|
||||
- name: netbsd
|
||||
arch: arm64
|
||||
version: '10.1'
|
||||
tflags: '-test.short'
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Git LFS pull
|
||||
uses: ./.github/actions/lfs
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Build
|
||||
env:
|
||||
GOOS: ${{ matrix.os.name }}
|
||||
TESTFLAGS: ${{ matrix.os.flags }}
|
||||
GOARCH: ${{ matrix.os.arch }}
|
||||
TESTFLAGS: ${{ matrix.os.tflags }}
|
||||
run: .github/workflows/build-test.sh
|
||||
|
||||
- name: Test
|
||||
uses: cross-platform-actions/action@v0.26.0
|
||||
uses: cross-platform-actions/action@v0.32.0
|
||||
with:
|
||||
operating_system: ${{ matrix.os.name }}
|
||||
operating_system: ${{ matrix.os.action || matrix.os.name }}
|
||||
architecture: ${{ matrix.os.arch }}
|
||||
version: ${{ matrix.os.version }}
|
||||
shell: bash
|
||||
run: . ./test.sh
|
||||
@@ -124,22 +144,16 @@ jobs:
|
||||
os:
|
||||
- name: dragonfly
|
||||
action: 'vmactions/dragonflybsd-vm@v1'
|
||||
tflags: '-test.v'
|
||||
- name: illumos
|
||||
action: 'vmactions/omnios-vm@v1'
|
||||
tflags: '-test.v'
|
||||
action: 'vmactions/openindiana-vm@v0'
|
||||
- name: solaris
|
||||
action: 'vmactions/solaris-vm@v1'
|
||||
bflags: '-tags sqlite3_dotlk'
|
||||
tflags: '-test.v'
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Git LFS pull
|
||||
uses: ./.github/actions/lfs
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Build
|
||||
env:
|
||||
@@ -151,10 +165,27 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
uses: ./.github/actions/vmactions
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
run: . ./test.sh
|
||||
|
||||
test-wasip1:
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- uses: bytecodealliance/actions/wasmtime/setup@v1
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Set path
|
||||
run: echo "$(go env GOROOT)/lib/wasm" >> "$GITHUB_PATH"
|
||||
|
||||
- name: Test wasmtime
|
||||
env:
|
||||
GOOS: wasip1
|
||||
GOARCH: wasm
|
||||
GOWASIRUNTIME: wasmtime
|
||||
GOWASIRUNTIMEARGS: '--env CI=true'
|
||||
run: go test -short -tags sqlite3_dotlk -skip Example ./...
|
||||
|
||||
test-qemu:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -162,39 +193,60 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: docker/setup-qemu-action@v3
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Git LFS pull
|
||||
uses: ./.github/actions/lfs
|
||||
|
||||
- name: Test 386 (32-bit)
|
||||
run: GOARCH=386 go test -v -short ./...
|
||||
|
||||
- name: Test arm64 (compiler)
|
||||
run: GOARCH=arm64 go test -v -short ./...
|
||||
run: GOARCH=386 go test -short ./...
|
||||
|
||||
- name: Test riscv64 (interpreter)
|
||||
run: GOARCH=riscv64 go test -v -short ./...
|
||||
run: GOARCH=riscv64 go test -short ./...
|
||||
|
||||
- name: Test ppc64le (interpreter)
|
||||
run: GOARCH=ppc64le go test -v -short ./...
|
||||
run: GOARCH=ppc64le go test -short ./...
|
||||
|
||||
- name: Test loong64 (interpreter)
|
||||
run: GOARCH=loong64 go test -short ./...
|
||||
|
||||
- name: Test s390x (big-endian)
|
||||
run: GOARCH=s390x go test -v -short -tags sqlite3_dotlk ./...
|
||||
run: GOARCH=s390x go test -short -tags sqlite3_dotlk ./...
|
||||
|
||||
test-macintel:
|
||||
runs-on: macos-13
|
||||
test-linuxarm:
|
||||
runs-on: ubuntu-24.04-arm
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Git LFS pull
|
||||
uses: ./.github/actions/lfs
|
||||
- name: Test
|
||||
run: go test ./...
|
||||
|
||||
- name: Test arm (32-bit)
|
||||
run: GOARCH=arm GOARM=7 go test -short ./...
|
||||
|
||||
test-macintel:
|
||||
runs-on: macos-15-intel
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
run: go test ./...
|
||||
|
||||
test-winarm:
|
||||
runs-on: windows-11-arm
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with: { go-version: stable }
|
||||
|
||||
- name: Test
|
||||
run: go test ./...
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -13,4 +13,11 @@
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
tools
|
||||
tools
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
go.work.sum
|
||||
|
||||
# env file
|
||||
.env
|
||||
51
README.md
51
README.md
@@ -30,10 +30,10 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
|
||||
|
||||
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
wraps the [C SQLite API](https://sqlite.org/cintro.html)
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
|
||||
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
@@ -44,12 +44,19 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
|
||||
### Advanced features
|
||||
|
||||
- [incremental BLOB I/O](https://sqlite.org/c3ref/blob_open.html)
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blobio#example-package))
|
||||
- [nested transactions](https://sqlite.org/lang_savepoint.html)
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-Savepoint))
|
||||
- [custom functions](https://sqlite.org/c3ref/create_function.html)
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-Conn.CreateFunction))
|
||||
- [virtual tables](https://sqlite.org/vtab.html)
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-CreateModule))
|
||||
- [custom VFSes](https://sqlite.org/vfs.html)
|
||||
([examples](vfs/README.md#custom-vfses))
|
||||
- [online backup](https://sqlite.org/backup.html)
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#Conn))
|
||||
- [JSON support](https://sqlite.org/json1.html)
|
||||
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package-Json))
|
||||
- [math functions](https://sqlite.org/lang_mathfunc.html)
|
||||
- [full-text search](https://sqlite.org/fts5.html)
|
||||
- [geospatial search](https://sqlite.org/geopoly.html)
|
||||
@@ -57,7 +64,6 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
|
||||
- [statistics functions](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
- [encryption at rest](vfs/adiantum/README.md)
|
||||
- [many extensions](ext/README.md)
|
||||
- [custom VFSes](vfs/README.md#custom-vfses)
|
||||
- [and more…](embed/README.md)
|
||||
|
||||
### Caveats
|
||||
@@ -65,31 +71,52 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
|
||||
This module replaces the SQLite [OS Interface](https://sqlite.org/vfs.html)
|
||||
(aka VFS) with a [pure Go](vfs/) implementation,
|
||||
which has advantages and disadvantages.
|
||||
|
||||
Read more about the Go VFS design [here](vfs/README.md).
|
||||
|
||||
Because each database connection executes within a Wasm sandboxed environment,
|
||||
memory usage will be higher than alternatives.
|
||||
|
||||
### Testing
|
||||
|
||||
This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report).
|
||||
It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and
|
||||
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing.
|
||||
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach)
|
||||
thorough testing.
|
||||
|
||||
Every commit is [tested](https://github.com/ncruces/go-sqlite3/wiki/Test-matrix) on
|
||||
Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (amd64/arm64),
|
||||
Windows (amd64), FreeBSD (amd64), OpenBSD (amd64), NetBSD (amd64),
|
||||
DragonFly BSD (amd64), illumos (amd64), and Solaris (amd64).
|
||||
Every commit is tested on:
|
||||
* Linux: amd64, arm64, 386, arm, riscv64, ppc64le, loong64, s390x
|
||||
* macOS: amd64, arm64
|
||||
* Windows: amd64, arm64
|
||||
* BSD:
|
||||
* FreeBSD: amd64, arm64
|
||||
* NetBSD: amd64, arm64
|
||||
* DragonFly BSD: amd64
|
||||
* OpenBSD: amd64
|
||||
* illumos: amd64
|
||||
* Solaris: amd64
|
||||
|
||||
Certain operating system and CPU combinations have some limitations. See the [support matrix](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) for a complete overview.
|
||||
|
||||
The Go VFS is tested by running SQLite's
|
||||
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c).
|
||||
|
||||
### Performance
|
||||
|
||||
Perfomance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
|
||||
Performance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
|
||||
[competitive](https://github.com/cvilsmeier/go-sqlite-bench) with alternatives.
|
||||
|
||||
The Wasm and VFS layers are also tested by running SQLite's
|
||||
The Wasm and VFS layers are also benchmarked by running SQLite's
|
||||
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
|
||||
|
||||
### Concurrency
|
||||
|
||||
This module behaves similarly to SQLite in [multi-thread](https://sqlite.org/threadsafe.html) mode:
|
||||
it is goroutine-safe, provided that no single database connection, or object derived from it,
|
||||
is used concurrently by multiple goroutines.
|
||||
|
||||
The [`database/sql`](https://pkg.go.dev/database/sql) API is safe to use concurrently,
|
||||
according to its documentation.
|
||||
|
||||
### FAQ, issues, new features
|
||||
|
||||
For questions, please see [Discussions](https://github.com/ncruces/go-sqlite3/discussions/categories/q-a).
|
||||
@@ -98,7 +125,7 @@ Also, post there if you used this driver for something interesting
|
||||
([_"Show and tell"_](https://github.com/ncruces/go-sqlite3/discussions/categories/show-and-tell)),
|
||||
have an [idea](https://github.com/ncruces/go-sqlite3/discussions/categories/ideas)…
|
||||
|
||||
The [Issue](https://github.com/ncruces/go-sqlite3/issues) tracker is for bugs we want fixed,
|
||||
The [Issue](https://github.com/ncruces/go-sqlite3/issues) tracker is for bugs,
|
||||
and features we're working on, planning to work on, or asking for help with.
|
||||
|
||||
### Alternatives
|
||||
|
||||
38
backup.go
38
backup.go
@@ -5,8 +5,8 @@ package sqlite3
|
||||
// https://sqlite.org/c3ref/backup.html
|
||||
type Backup struct {
|
||||
c *Conn
|
||||
handle uint32
|
||||
otherc uint32
|
||||
handle ptr_t
|
||||
otherc ptr_t
|
||||
}
|
||||
|
||||
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
|
||||
@@ -61,7 +61,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
|
||||
return src.backupInit(dst, "main", src.handle, srcDB)
|
||||
}
|
||||
|
||||
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
|
||||
func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) (*Backup, error) {
|
||||
defer c.arena.mark()()
|
||||
dstPtr := c.arena.string(dstName)
|
||||
srcPtr := c.arena.string(srcName)
|
||||
@@ -71,19 +71,19 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
|
||||
other = src
|
||||
}
|
||||
|
||||
r := c.call("sqlite3_backup_init",
|
||||
uint64(dst), uint64(dstPtr),
|
||||
uint64(src), uint64(srcPtr))
|
||||
if r == 0 {
|
||||
ptr := ptr_t(c.call("sqlite3_backup_init",
|
||||
stk_t(dst), stk_t(dstPtr),
|
||||
stk_t(src), stk_t(srcPtr)))
|
||||
if ptr == 0 {
|
||||
defer c.closeDB(other)
|
||||
r = c.call("sqlite3_errcode", uint64(dst))
|
||||
return nil, c.sqlite.error(r, dst)
|
||||
rc := res_t(c.call("sqlite3_errcode", stk_t(dst)))
|
||||
return nil, c.sqlite.error(rc, dst)
|
||||
}
|
||||
|
||||
return &Backup{
|
||||
c: c,
|
||||
otherc: other,
|
||||
handle: uint32(r),
|
||||
handle: ptr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -97,10 +97,10 @@ func (b *Backup) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := b.c.call("sqlite3_backup_finish", uint64(b.handle))
|
||||
rc := res_t(b.c.call("sqlite3_backup_finish", stk_t(b.handle)))
|
||||
b.c.closeDB(b.otherc)
|
||||
b.handle = 0
|
||||
return b.c.error(r)
|
||||
return b.c.error(rc)
|
||||
}
|
||||
|
||||
// Step copies up to nPage pages between the source and destination databases.
|
||||
@@ -108,11 +108,11 @@ func (b *Backup) Close() error {
|
||||
//
|
||||
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
|
||||
func (b *Backup) Step(nPage int) (done bool, err error) {
|
||||
r := b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage))
|
||||
if r == _DONE {
|
||||
rc := res_t(b.c.call("sqlite3_backup_step", stk_t(b.handle), stk_t(nPage)))
|
||||
if rc == _DONE {
|
||||
return true, nil
|
||||
}
|
||||
return false, b.c.error(r)
|
||||
return false, b.c.error(rc)
|
||||
}
|
||||
|
||||
// Remaining returns the number of pages still to be backed up
|
||||
@@ -120,8 +120,8 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
|
||||
//
|
||||
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
|
||||
func (b *Backup) Remaining() int {
|
||||
r := b.c.call("sqlite3_backup_remaining", uint64(b.handle))
|
||||
return int(int32(r))
|
||||
n := int32(b.c.call("sqlite3_backup_remaining", stk_t(b.handle)))
|
||||
return int(n)
|
||||
}
|
||||
|
||||
// PageCount returns the total number of pages in the source database
|
||||
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
|
||||
//
|
||||
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
|
||||
func (b *Backup) PageCount() int {
|
||||
r := b.c.call("sqlite3_backup_pagecount", uint64(b.handle))
|
||||
return int(int32(r))
|
||||
n := int32(b.c.call("sqlite3_backup_pagecount", stk_t(b.handle)))
|
||||
return int(n)
|
||||
}
|
||||
|
||||
73
blob.go
73
blob.go
@@ -20,8 +20,8 @@ type Blob struct {
|
||||
c *Conn
|
||||
bytes int64
|
||||
offset int64
|
||||
handle uint32
|
||||
bufptr uint32
|
||||
handle ptr_t
|
||||
bufptr ptr_t
|
||||
buflen int64
|
||||
}
|
||||
|
||||
@@ -31,29 +31,32 @@ var _ io.ReadWriteSeeker = &Blob{}
|
||||
//
|
||||
// https://sqlite.org/c3ref/blob_open.html
|
||||
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
|
||||
if c.interrupt.Err() != nil {
|
||||
return nil, INTERRUPT
|
||||
}
|
||||
|
||||
defer c.arena.mark()()
|
||||
blobPtr := c.arena.new(ptrlen)
|
||||
dbPtr := c.arena.string(db)
|
||||
tablePtr := c.arena.string(table)
|
||||
columnPtr := c.arena.string(column)
|
||||
|
||||
var flags uint64
|
||||
var flags int32
|
||||
if write {
|
||||
flags = 1
|
||||
}
|
||||
|
||||
c.checkInterrupt(c.handle)
|
||||
r := c.call("sqlite3_blob_open", uint64(c.handle),
|
||||
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
|
||||
uint64(row), flags, uint64(blobPtr))
|
||||
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
|
||||
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
|
||||
stk_t(row), stk_t(flags), stk_t(blobPtr)))
|
||||
|
||||
if err := c.error(r); err != nil {
|
||||
if err := c.error(rc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob := Blob{c: c}
|
||||
blob.handle = util.ReadUint32(c.mod, blobPtr)
|
||||
blob.bytes = int64(c.call("sqlite3_blob_bytes", uint64(blob.handle)))
|
||||
blob.handle = util.Read32[ptr_t](c.mod, blobPtr)
|
||||
blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", stk_t(blob.handle))))
|
||||
return &blob, nil
|
||||
}
|
||||
|
||||
@@ -67,10 +70,10 @@ func (b *Blob) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := b.c.call("sqlite3_blob_close", uint64(b.handle))
|
||||
rc := res_t(b.c.call("sqlite3_blob_close", stk_t(b.handle)))
|
||||
b.c.free(b.bufptr)
|
||||
b.handle = 0
|
||||
return b.c.error(r)
|
||||
return b.c.error(rc)
|
||||
}
|
||||
|
||||
// Size returns the size of the BLOB in bytes.
|
||||
@@ -94,13 +97,13 @@ func (b *Blob) Read(p []byte) (n int, err error) {
|
||||
want = avail
|
||||
}
|
||||
if want > b.buflen {
|
||||
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
|
||||
b.bufptr = b.c.realloc(b.bufptr, want)
|
||||
b.buflen = want
|
||||
}
|
||||
|
||||
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
|
||||
uint64(b.bufptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r)
|
||||
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
|
||||
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
|
||||
err = b.c.error(rc)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -109,7 +112,7 @@ func (b *Blob) Read(p []byte) (n int, err error) {
|
||||
err = io.EOF
|
||||
}
|
||||
|
||||
copy(p, util.View(b.c.mod, b.bufptr, uint64(want)))
|
||||
copy(p, util.View(b.c.mod, b.bufptr, want))
|
||||
return int(want), err
|
||||
}
|
||||
|
||||
@@ -127,19 +130,19 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
|
||||
want = avail
|
||||
}
|
||||
if want > b.buflen {
|
||||
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
|
||||
b.bufptr = b.c.realloc(b.bufptr, want)
|
||||
b.buflen = want
|
||||
}
|
||||
|
||||
for want > 0 {
|
||||
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
|
||||
uint64(b.bufptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r)
|
||||
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
|
||||
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
|
||||
err = b.c.error(rc)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
mem := util.View(b.c.mod, b.bufptr, uint64(want))
|
||||
mem := util.View(b.c.mod, b.bufptr, want)
|
||||
m, err := w.Write(mem[:want])
|
||||
b.offset += int64(m)
|
||||
n += int64(m)
|
||||
@@ -165,14 +168,14 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
|
||||
func (b *Blob) Write(p []byte) (n int, err error) {
|
||||
want := int64(len(p))
|
||||
if want > b.buflen {
|
||||
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
|
||||
b.bufptr = b.c.realloc(b.bufptr, want)
|
||||
b.buflen = want
|
||||
}
|
||||
util.WriteBytes(b.c.mod, b.bufptr, p)
|
||||
|
||||
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
|
||||
uint64(b.bufptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r)
|
||||
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
|
||||
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
|
||||
err = b.c.error(rc)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -196,17 +199,17 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
want = 1
|
||||
}
|
||||
if want > b.buflen {
|
||||
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
|
||||
b.bufptr = b.c.realloc(b.bufptr, want)
|
||||
b.buflen = want
|
||||
}
|
||||
|
||||
for {
|
||||
mem := util.View(b.c.mod, b.bufptr, uint64(want))
|
||||
mem := util.View(b.c.mod, b.bufptr, want)
|
||||
m, err := r.Read(mem[:want])
|
||||
if m > 0 {
|
||||
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
|
||||
uint64(b.bufptr), uint64(m), uint64(b.offset))
|
||||
err := b.c.error(r)
|
||||
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
|
||||
stk_t(b.bufptr), stk_t(m), stk_t(b.offset)))
|
||||
err := b.c.error(rc)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
@@ -253,9 +256,11 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
|
||||
//
|
||||
// https://sqlite.org/c3ref/blob_reopen.html
|
||||
func (b *Blob) Reopen(row int64) error {
|
||||
b.c.checkInterrupt(b.c.handle)
|
||||
err := b.c.error(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row)))
|
||||
b.bytes = int64(b.c.call("sqlite3_blob_bytes", uint64(b.handle)))
|
||||
if b.c.interrupt.Err() != nil {
|
||||
return INTERRUPT
|
||||
}
|
||||
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
|
||||
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
|
||||
b.offset = 0
|
||||
return err
|
||||
}
|
||||
|
||||
178
config.go
178
config.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
|
||||
@@ -32,7 +33,7 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
|
||||
defer c.arena.mark()()
|
||||
argsPtr := c.arena.new(intlen + ptrlen)
|
||||
|
||||
var flag int
|
||||
var flag int32
|
||||
switch {
|
||||
case len(arg) == 0:
|
||||
flag = -1
|
||||
@@ -40,31 +41,40 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
|
||||
flag = 1
|
||||
}
|
||||
|
||||
util.WriteUint32(c.mod, argsPtr+0*ptrlen, uint32(flag))
|
||||
util.WriteUint32(c.mod, argsPtr+1*ptrlen, argsPtr)
|
||||
util.Write32(c.mod, argsPtr+0*ptrlen, flag)
|
||||
util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr)
|
||||
|
||||
r := c.call("sqlite3_db_config", uint64(c.handle),
|
||||
uint64(op), uint64(argsPtr))
|
||||
return util.ReadUint32(c.mod, argsPtr) != 0, c.error(r)
|
||||
rc := res_t(c.call("sqlite3_db_config", stk_t(c.handle),
|
||||
stk_t(op), stk_t(argsPtr)))
|
||||
return util.ReadBool(c.mod, argsPtr), c.error(rc)
|
||||
}
|
||||
|
||||
var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)]
|
||||
|
||||
// ConfigLog sets up the default error logging callback for new connections.
|
||||
//
|
||||
// https://sqlite.org/errlog.html
|
||||
func ConfigLog(cb func(code ExtendedErrorCode, msg string)) {
|
||||
defaultLogger.Store(&cb)
|
||||
}
|
||||
|
||||
// ConfigLog sets up the error logging callback for the connection.
|
||||
//
|
||||
// https://sqlite.org/errlog.html
|
||||
func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error {
|
||||
var enable uint64
|
||||
var enable int32
|
||||
if cb != nil {
|
||||
enable = 1
|
||||
}
|
||||
r := c.call("sqlite3_config_log_go", enable)
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_config_log_go", stk_t(enable)))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
c.log = cb
|
||||
return nil
|
||||
}
|
||||
|
||||
func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg uint32) {
|
||||
func logCallback(ctx context.Context, mod api.Module, _ ptr_t, iCode res_t, zMsg ptr_t) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil {
|
||||
msg := util.ReadString(mod, zMsg, _MAX_LENGTH)
|
||||
c.log(xErrorCode(iCode), msg)
|
||||
@@ -88,93 +98,93 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro
|
||||
defer c.arena.mark()()
|
||||
ptr := c.arena.new(max(ptrlen, intlen))
|
||||
|
||||
var schemaPtr uint32
|
||||
var schemaPtr ptr_t
|
||||
if schema != "" {
|
||||
schemaPtr = c.arena.string(schema)
|
||||
}
|
||||
|
||||
var rc uint64
|
||||
var res any
|
||||
var rc res_t
|
||||
var ret any
|
||||
switch op {
|
||||
default:
|
||||
return nil, MISUSE
|
||||
|
||||
case FCNTL_RESET_CACHE:
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), 0)
|
||||
case FCNTL_RESET_CACHE, FCNTL_NULL_IO:
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), 0))
|
||||
|
||||
case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE:
|
||||
var flag int
|
||||
var flag int32
|
||||
switch {
|
||||
case len(arg) == 0:
|
||||
flag = -1
|
||||
case arg[0]:
|
||||
flag = 1
|
||||
}
|
||||
util.WriteUint32(c.mod, ptr, uint32(flag))
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
res = util.ReadUint32(c.mod, ptr) != 0
|
||||
util.Write32(c.mod, ptr, flag)
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
ret = util.ReadBool(c.mod, ptr)
|
||||
|
||||
case FCNTL_CHUNK_SIZE:
|
||||
util.WriteUint32(c.mod, ptr, uint32(arg[0].(int)))
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
util.Write32(c.mod, ptr, int32(arg[0].(int)))
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
|
||||
case FCNTL_RESERVE_BYTES:
|
||||
bytes := -1
|
||||
if len(arg) > 0 {
|
||||
bytes = arg[0].(int)
|
||||
}
|
||||
util.WriteUint32(c.mod, ptr, uint32(bytes))
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
res = int(util.ReadUint32(c.mod, ptr))
|
||||
util.Write32(c.mod, ptr, int32(bytes))
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
ret = int(util.Read32[int32](c.mod, ptr))
|
||||
|
||||
case FCNTL_DATA_VERSION:
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
res = util.ReadUint32(c.mod, ptr)
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
ret = util.Read32[uint32](c.mod, ptr)
|
||||
|
||||
case FCNTL_LOCKSTATE:
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
res = vfs.LockLevel(util.ReadUint32(c.mod, ptr))
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
ret = util.Read32[vfs.LockLevel](c.mod, ptr)
|
||||
|
||||
case FCNTL_VFS_POINTER:
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
if rc == _OK {
|
||||
const zNameOffset = 16
|
||||
ptr = util.ReadUint32(c.mod, ptr)
|
||||
ptr = util.ReadUint32(c.mod, ptr+zNameOffset)
|
||||
ptr = util.Read32[ptr_t](c.mod, ptr)
|
||||
ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset)
|
||||
name := util.ReadString(c.mod, ptr, _MAX_NAME)
|
||||
res = vfs.Find(name)
|
||||
ret = vfs.Find(name)
|
||||
}
|
||||
|
||||
case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER:
|
||||
rc = c.call("sqlite3_file_control",
|
||||
uint64(c.handle), uint64(schemaPtr),
|
||||
uint64(op), uint64(ptr))
|
||||
rc = res_t(c.call("sqlite3_file_control",
|
||||
stk_t(c.handle), stk_t(schemaPtr),
|
||||
stk_t(op), stk_t(ptr)))
|
||||
if rc == _OK {
|
||||
const fileHandleOffset = 4
|
||||
ptr = util.ReadUint32(c.mod, ptr)
|
||||
ptr = util.ReadUint32(c.mod, ptr+fileHandleOffset)
|
||||
res = util.GetHandle(c.ctx, ptr)
|
||||
ptr = util.Read32[ptr_t](c.mod, ptr)
|
||||
ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset)
|
||||
ret = util.GetHandle(c.ctx, ptr)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.error(rc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Limit allows the size of various constructs to be
|
||||
@@ -182,20 +192,20 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro
|
||||
//
|
||||
// https://sqlite.org/c3ref/limit.html
|
||||
func (c *Conn) Limit(id LimitCategory, value int) int {
|
||||
r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value))
|
||||
return int(int32(r))
|
||||
v := int32(c.call("sqlite3_limit", stk_t(c.handle), stk_t(id), stk_t(value)))
|
||||
return int(v)
|
||||
}
|
||||
|
||||
// SetAuthorizer registers an authorizer callback with the database connection.
|
||||
//
|
||||
// https://sqlite.org/c3ref/set_authorizer.html
|
||||
func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error {
|
||||
var enable uint64
|
||||
var enable int32
|
||||
if cb != nil {
|
||||
enable = 1
|
||||
}
|
||||
r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable)
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_set_authorizer_go", stk_t(c.handle), stk_t(enable)))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
c.authorizer = cb
|
||||
@@ -203,7 +213,7 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4
|
||||
|
||||
}
|
||||
|
||||
func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner uint32) (rc AuthorizerReturnCode) {
|
||||
func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner ptr_t) (rc AuthorizerReturnCode) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil {
|
||||
var name3rd, name4th, schema, inner string
|
||||
if zName3rd != 0 {
|
||||
@@ -227,15 +237,15 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action
|
||||
//
|
||||
// https://sqlite.org/c3ref/trace_v2.html
|
||||
func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error {
|
||||
r := c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask))
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_trace_go", stk_t(c.handle), stk_t(mask)))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
c.trace = cb
|
||||
return nil
|
||||
}
|
||||
|
||||
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 uint32) (rc uint32) {
|
||||
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc res_t) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil {
|
||||
var arg1, arg2 any
|
||||
if evt == TRACE_CLOSE {
|
||||
@@ -248,14 +258,14 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
|
||||
case TRACE_STMT:
|
||||
arg2 = s.SQL()
|
||||
case TRACE_PROFILE:
|
||||
arg2 = int64(util.ReadUint64(mod, pArg2))
|
||||
arg2 = util.Read64[int64](mod, pArg2)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if arg1 != nil {
|
||||
_, rc = errorCode(c.trace(evt, arg1, arg2), ERROR)
|
||||
_ = c.trace(evt, arg1, arg2)
|
||||
}
|
||||
}
|
||||
return rc
|
||||
@@ -265,24 +275,28 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
|
||||
//
|
||||
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
|
||||
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
|
||||
if c.interrupt.Err() != nil {
|
||||
return 0, 0, INTERRUPT
|
||||
}
|
||||
|
||||
defer c.arena.mark()()
|
||||
nLogPtr := c.arena.new(ptrlen)
|
||||
nCkptPtr := c.arena.new(ptrlen)
|
||||
schemaPtr := c.arena.string(schema)
|
||||
r := c.call("sqlite3_wal_checkpoint_v2",
|
||||
uint64(c.handle), uint64(schemaPtr), uint64(mode),
|
||||
uint64(nLogPtr), uint64(nCkptPtr))
|
||||
nLog = int(int32(util.ReadUint32(c.mod, nLogPtr)))
|
||||
nCkpt = int(int32(util.ReadUint32(c.mod, nCkptPtr)))
|
||||
return nLog, nCkpt, c.error(r)
|
||||
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
|
||||
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
|
||||
stk_t(nLogPtr), stk_t(nCkptPtr)))
|
||||
nLog = int(util.Read32[int32](c.mod, nLogPtr))
|
||||
nCkpt = int(util.Read32[int32](c.mod, nCkptPtr))
|
||||
return nLog, nCkpt, c.error(rc)
|
||||
}
|
||||
|
||||
// WALAutoCheckpoint configures WAL auto-checkpoints.
|
||||
//
|
||||
// https://sqlite.org/c3ref/wal_autocheckpoint.html
|
||||
func (c *Conn) WALAutoCheckpoint(pages int) error {
|
||||
r := c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_wal_autocheckpoint", stk_t(c.handle), stk_t(pages)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
// WALHook registers a callback function to be invoked
|
||||
@@ -290,15 +304,15 @@ func (c *Conn) WALAutoCheckpoint(pages int) error {
|
||||
//
|
||||
// https://sqlite.org/c3ref/wal_hook.html
|
||||
func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) {
|
||||
var enable uint64
|
||||
var enable int32
|
||||
if cb != nil {
|
||||
enable = 1
|
||||
}
|
||||
c.call("sqlite3_wal_hook_go", uint64(c.handle), enable)
|
||||
c.call("sqlite3_wal_hook_go", stk_t(c.handle), stk_t(enable))
|
||||
c.wal = cb
|
||||
}
|
||||
|
||||
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pages int32) (rc uint32) {
|
||||
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc res_t) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil {
|
||||
schema := util.ReadString(mod, zSchema, _MAX_NAME)
|
||||
err := c.wal(c, schema, int(pages))
|
||||
@@ -311,15 +325,15 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa
|
||||
//
|
||||
// https://sqlite.org/c3ref/autovacuum_pages.html
|
||||
func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error {
|
||||
var funcPtr uint32
|
||||
var funcPtr ptr_t
|
||||
if cb != nil {
|
||||
funcPtr = util.AddHandle(c.ctx, cb)
|
||||
}
|
||||
r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_autovacuum_pages_go", stk_t(c.handle), stk_t(funcPtr)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
|
||||
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(schema string, dbPages, freePages, bytesPerPage uint) uint)
|
||||
schema := util.ReadString(mod, zSchema, _MAX_NAME)
|
||||
return uint32(fn(schema, uint(nDbPage), uint(nFreePage), uint(nBytePerPage)))
|
||||
@@ -329,14 +343,14 @@ func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbP
|
||||
//
|
||||
// https://sqlite.org/c3ref/hard_heap_limit64.html
|
||||
func (c *Conn) SoftHeapLimit(n int64) int64 {
|
||||
return int64(c.call("sqlite3_soft_heap_limit64", uint64(n)))
|
||||
return int64(c.call("sqlite3_soft_heap_limit64", stk_t(n)))
|
||||
}
|
||||
|
||||
// HardHeapLimit imposes a hard limit on heap size.
|
||||
//
|
||||
// https://sqlite.org/c3ref/hard_heap_limit64.html
|
||||
func (c *Conn) HardHeapLimit(n int64) int64 {
|
||||
return int64(c.call("sqlite3_hard_heap_limit64", uint64(n)))
|
||||
return int64(c.call("sqlite3_hard_heap_limit64", stk_t(n)))
|
||||
}
|
||||
|
||||
// EnableChecksums enables checksums on a database.
|
||||
@@ -378,6 +392,6 @@ func (c *Conn) EnableChecksums(schema string) error {
|
||||
}
|
||||
|
||||
// Checkpoint the WAL.
|
||||
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_RESTART)
|
||||
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_FULL)
|
||||
return err
|
||||
}
|
||||
|
||||
264
conn.go
264
conn.go
@@ -3,9 +3,11 @@ package sqlite3
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -23,7 +25,6 @@ type Conn struct {
|
||||
*sqlite
|
||||
|
||||
interrupt context.Context
|
||||
pending *Stmt
|
||||
stmts []*Stmt
|
||||
busy func(context.Context, int) bool
|
||||
log func(xErrorCode, string)
|
||||
@@ -34,11 +35,12 @@ type Conn struct {
|
||||
update func(AuthorizerActionCode, string, string, int64)
|
||||
commit func() bool
|
||||
rollback func()
|
||||
arena arena
|
||||
|
||||
busy1st time.Time
|
||||
busylst time.Time
|
||||
handle uint32
|
||||
arena arena
|
||||
handle ptr_t
|
||||
gosched uint8
|
||||
}
|
||||
|
||||
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
|
||||
@@ -47,7 +49,7 @@ func Open(filename string) (*Conn, error) {
|
||||
}
|
||||
|
||||
// OpenContext is like [Open] but includes a context,
|
||||
// which is used to interrupt the process of opening the connectiton.
|
||||
// which is used to interrupt the process of opening the connection.
|
||||
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
|
||||
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
|
||||
}
|
||||
@@ -67,9 +69,9 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
return newConn(context.Background(), filename, flags)
|
||||
}
|
||||
|
||||
type connKey struct{}
|
||||
type connKey = util.ConnKey
|
||||
|
||||
func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) {
|
||||
func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ error) {
|
||||
err := ctx.Err()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -81,7 +83,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if res == nil {
|
||||
if ret == nil {
|
||||
c.Close()
|
||||
c.sqlite.close()
|
||||
} else {
|
||||
@@ -90,7 +92,10 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _
|
||||
}()
|
||||
|
||||
c.ctx = context.WithValue(c.ctx, connKey{}, c)
|
||||
c.arena = c.newArena(1024)
|
||||
if logger := defaultLogger.Load(); logger != nil {
|
||||
c.ConfigLog(*logger)
|
||||
}
|
||||
c.arena = c.newArena()
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err == nil {
|
||||
err = initExtensions(c)
|
||||
@@ -101,21 +106,21 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
|
||||
defer c.arena.mark()()
|
||||
connPtr := c.arena.new(ptrlen)
|
||||
namePtr := c.arena.string(filename)
|
||||
|
||||
flags |= OPEN_EXRESCODE
|
||||
r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
rc := res_t(c.call("sqlite3_open_v2", stk_t(namePtr), stk_t(connPtr), stk_t(flags), 0))
|
||||
|
||||
handle := util.ReadUint32(c.mod, connPtr)
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
handle := util.Read32[ptr_t](c.mod, connPtr)
|
||||
if err := c.sqlite.error(rc, handle); err != nil {
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
|
||||
c.call("sqlite3_progress_handler_go", stk_t(handle), 1000)
|
||||
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
|
||||
var pragmas strings.Builder
|
||||
if _, after, ok := strings.Cut(filename, "?"); ok {
|
||||
@@ -127,10 +132,9 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
}
|
||||
}
|
||||
if pragmas.Len() != 0 {
|
||||
c.checkInterrupt(handle)
|
||||
pragmaPtr := c.arena.string(pragmas.String())
|
||||
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
|
||||
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
|
||||
rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0))
|
||||
if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil {
|
||||
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
@@ -140,9 +144,9 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func (c *Conn) closeDB(handle uint32) {
|
||||
r := c.call("sqlite3_close_v2", uint64(handle))
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
func (c *Conn) closeDB(handle ptr_t) {
|
||||
rc := res_t(c.call("sqlite3_close_v2", stk_t(handle)))
|
||||
if err := c.sqlite.error(rc, handle); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -161,11 +165,8 @@ func (c *Conn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
|
||||
r := c.call("sqlite3_close", uint64(c.handle))
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -178,12 +179,17 @@ func (c *Conn) Close() error {
|
||||
//
|
||||
// https://sqlite.org/c3ref/exec.html
|
||||
func (c *Conn) Exec(sql string) error {
|
||||
defer c.arena.mark()()
|
||||
sqlPtr := c.arena.string(sql)
|
||||
if c.interrupt.Err() != nil {
|
||||
return INTERRUPT
|
||||
}
|
||||
return c.exec(sql)
|
||||
}
|
||||
|
||||
c.checkInterrupt(c.handle)
|
||||
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
|
||||
return c.error(r, sql)
|
||||
func (c *Conn) exec(sql string) error {
|
||||
defer c.arena.mark()()
|
||||
textPtr := c.arena.string(sql)
|
||||
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
|
||||
return c.error(rc, sql)
|
||||
}
|
||||
|
||||
// Prepare calls [Conn.PrepareFlags] with no flags.
|
||||
@@ -201,24 +207,26 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
if len(sql) > _MAX_SQL_LENGTH {
|
||||
return nil, "", TOOBIG
|
||||
}
|
||||
if c.interrupt.Err() != nil {
|
||||
return nil, "", INTERRUPT
|
||||
}
|
||||
|
||||
defer c.arena.mark()()
|
||||
stmtPtr := c.arena.new(ptrlen)
|
||||
tailPtr := c.arena.new(ptrlen)
|
||||
sqlPtr := c.arena.string(sql)
|
||||
textPtr := c.arena.string(sql)
|
||||
|
||||
c.checkInterrupt(c.handle)
|
||||
r := c.call("sqlite3_prepare_v3", uint64(c.handle),
|
||||
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
|
||||
uint64(stmtPtr), uint64(tailPtr))
|
||||
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
|
||||
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
|
||||
stk_t(stmtPtr), stk_t(tailPtr)))
|
||||
|
||||
stmt = &Stmt{c: c}
|
||||
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
|
||||
if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" {
|
||||
stmt = &Stmt{c: c, sql: sql}
|
||||
stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr)
|
||||
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" {
|
||||
tail = sql
|
||||
}
|
||||
|
||||
if err := c.error(r, sql); err != nil {
|
||||
if err := c.error(rc, sql); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if stmt.handle == 0 {
|
||||
@@ -232,9 +240,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
//
|
||||
// https://sqlite.org/c3ref/db_name.html
|
||||
func (c *Conn) DBName(n int) string {
|
||||
r := c.call("sqlite3_db_name", uint64(c.handle), uint64(n))
|
||||
|
||||
ptr := uint32(r)
|
||||
ptr := ptr_t(c.call("sqlite3_db_name", stk_t(c.handle), stk_t(n)))
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -245,34 +251,34 @@ func (c *Conn) DBName(n int) string {
|
||||
//
|
||||
// https://sqlite.org/c3ref/db_filename.html
|
||||
func (c *Conn) Filename(schema string) *vfs.Filename {
|
||||
var ptr uint32
|
||||
var ptr ptr_t
|
||||
if schema != "" {
|
||||
defer c.arena.mark()()
|
||||
ptr = c.arena.string(schema)
|
||||
}
|
||||
r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))
|
||||
return vfs.GetFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB)
|
||||
ptr = ptr_t(c.call("sqlite3_db_filename", stk_t(c.handle), stk_t(ptr)))
|
||||
return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB)
|
||||
}
|
||||
|
||||
// ReadOnly determines if a database is read-only.
|
||||
//
|
||||
// https://sqlite.org/c3ref/db_readonly.html
|
||||
func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) {
|
||||
var ptr uint32
|
||||
var ptr ptr_t
|
||||
if schema != "" {
|
||||
defer c.arena.mark()()
|
||||
ptr = c.arena.string(schema)
|
||||
}
|
||||
r := c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr))
|
||||
return int32(r) > 0, int32(r) < 0
|
||||
b := int32(c.call("sqlite3_db_readonly", stk_t(c.handle), stk_t(ptr)))
|
||||
return b > 0, b < 0
|
||||
}
|
||||
|
||||
// GetAutocommit tests the connection for auto-commit mode.
|
||||
//
|
||||
// https://sqlite.org/c3ref/get_autocommit.html
|
||||
func (c *Conn) GetAutocommit() bool {
|
||||
r := c.call("sqlite3_get_autocommit", uint64(c.handle))
|
||||
return r != 0
|
||||
b := int32(c.call("sqlite3_get_autocommit", stk_t(c.handle)))
|
||||
return b != 0
|
||||
}
|
||||
|
||||
// LastInsertRowID returns the rowid of the most recent successful INSERT
|
||||
@@ -280,8 +286,7 @@ func (c *Conn) GetAutocommit() bool {
|
||||
//
|
||||
// https://sqlite.org/c3ref/last_insert_rowid.html
|
||||
func (c *Conn) LastInsertRowID() int64 {
|
||||
r := c.call("sqlite3_last_insert_rowid", uint64(c.handle))
|
||||
return int64(r)
|
||||
return int64(c.call("sqlite3_last_insert_rowid", stk_t(c.handle)))
|
||||
}
|
||||
|
||||
// SetLastInsertRowID allows the application to set the value returned by
|
||||
@@ -289,7 +294,7 @@ func (c *Conn) LastInsertRowID() int64 {
|
||||
//
|
||||
// https://sqlite.org/c3ref/set_last_insert_rowid.html
|
||||
func (c *Conn) SetLastInsertRowID(id int64) {
|
||||
c.call("sqlite3_set_last_insert_rowid", uint64(c.handle), uint64(id))
|
||||
c.call("sqlite3_set_last_insert_rowid", stk_t(c.handle), stk_t(id))
|
||||
}
|
||||
|
||||
// Changes returns the number of rows modified, inserted or deleted
|
||||
@@ -298,8 +303,7 @@ func (c *Conn) SetLastInsertRowID(id int64) {
|
||||
//
|
||||
// https://sqlite.org/c3ref/changes.html
|
||||
func (c *Conn) Changes() int64 {
|
||||
r := c.call("sqlite3_changes64", uint64(c.handle))
|
||||
return int64(r)
|
||||
return int64(c.call("sqlite3_changes64", stk_t(c.handle)))
|
||||
}
|
||||
|
||||
// TotalChanges returns the number of rows modified, inserted or deleted
|
||||
@@ -308,16 +312,15 @@ func (c *Conn) Changes() int64 {
|
||||
//
|
||||
// https://sqlite.org/c3ref/total_changes.html
|
||||
func (c *Conn) TotalChanges() int64 {
|
||||
r := c.call("sqlite3_total_changes64", uint64(c.handle))
|
||||
return int64(r)
|
||||
return int64(c.call("sqlite3_total_changes64", stk_t(c.handle)))
|
||||
}
|
||||
|
||||
// ReleaseMemory frees memory used by a database connection.
|
||||
//
|
||||
// https://sqlite.org/c3ref/db_release_memory.html
|
||||
func (c *Conn) ReleaseMemory() error {
|
||||
r := c.call("sqlite3_db_release_memory", uint64(c.handle))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_db_release_memory", stk_t(c.handle)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
// GetInterrupt gets the context set with [Conn.SetInterrupt].
|
||||
@@ -340,43 +343,22 @@ func (c *Conn) GetInterrupt() context.Context {
|
||||
//
|
||||
// https://sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
if ctx == nil {
|
||||
panic("nil Context")
|
||||
}
|
||||
old = c.interrupt
|
||||
c.interrupt = ctx
|
||||
|
||||
if ctx == old || ctx.Done() == old.Done() {
|
||||
return old
|
||||
}
|
||||
|
||||
// A busy SQL statement prevents SQLite from ignoring an interrupt
|
||||
// that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
defer c.arena.mark()()
|
||||
stmtPtr := c.arena.new(ptrlen)
|
||||
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
|
||||
c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64,
|
||||
uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0)
|
||||
c.pending = &Stmt{c: c}
|
||||
c.pending.handle = util.ReadUint32(c.mod, stmtPtr)
|
||||
}
|
||||
|
||||
if old.Done() != nil && ctx.Err() == nil {
|
||||
c.pending.Reset()
|
||||
}
|
||||
if ctx.Done() != nil {
|
||||
c.pending.Step()
|
||||
}
|
||||
return old
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt(handle uint32) {
|
||||
if c.interrupt.Err() != nil {
|
||||
c.call("sqlite3_interrupt", uint64(handle))
|
||||
}
|
||||
}
|
||||
|
||||
func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() != nil {
|
||||
interrupt = 1
|
||||
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
|
||||
if c.gosched++; c.gosched%16 == 0 {
|
||||
runtime.Gosched()
|
||||
}
|
||||
if c.interrupt.Err() != nil {
|
||||
interrupt = 1
|
||||
}
|
||||
}
|
||||
return interrupt
|
||||
}
|
||||
@@ -386,11 +368,11 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt
|
||||
// https://sqlite.org/c3ref/busy_timeout.html
|
||||
func (c *Conn) BusyTimeout(timeout time.Duration) error {
|
||||
ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32)
|
||||
r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_busy_timeout", stk_t(c.handle), stk_t(ms)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) {
|
||||
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry int32) {
|
||||
// https://fractaledmind.github.io/2024/04/15/sqlite-on-rails-the-how-and-why-of-optimal-performance/
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil {
|
||||
switch {
|
||||
@@ -413,25 +395,22 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r
|
||||
//
|
||||
// https://sqlite.org/c3ref/busy_handler.html
|
||||
func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error {
|
||||
var enable uint64
|
||||
var enable int32
|
||||
if cb != nil {
|
||||
enable = 1
|
||||
}
|
||||
r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_busy_handler_go", stk_t(c.handle), stk_t(enable)))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
c.busy = cb
|
||||
return nil
|
||||
}
|
||||
|
||||
func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
|
||||
func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
|
||||
interrupt := c.interrupt
|
||||
if interrupt == nil {
|
||||
interrupt = context.Background()
|
||||
}
|
||||
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
|
||||
if interrupt := c.interrupt; interrupt.Err() == nil &&
|
||||
c.busy(interrupt, int(count)) {
|
||||
retry = 1
|
||||
}
|
||||
}
|
||||
@@ -441,21 +420,21 @@ func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32)
|
||||
// Status retrieves runtime status information about a database connection.
|
||||
//
|
||||
// https://sqlite.org/c3ref/db_status.html
|
||||
func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err error) {
|
||||
func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int64, err error) {
|
||||
defer c.arena.mark()()
|
||||
hiPtr := c.arena.new(intlen)
|
||||
curPtr := c.arena.new(intlen)
|
||||
hiPtr := c.arena.new(8)
|
||||
curPtr := c.arena.new(8)
|
||||
|
||||
var i uint64
|
||||
var i int32
|
||||
if reset {
|
||||
i = 1
|
||||
}
|
||||
|
||||
r := c.call("sqlite3_db_status", uint64(c.handle),
|
||||
uint64(op), uint64(curPtr), uint64(hiPtr), i)
|
||||
if err = c.error(r); err == nil {
|
||||
current = int(util.ReadUint32(c.mod, curPtr))
|
||||
highwater = int(util.ReadUint32(c.mod, hiPtr))
|
||||
rc := res_t(c.call("sqlite3_db_status64", stk_t(c.handle),
|
||||
stk_t(op), stk_t(curPtr), stk_t(hiPtr), stk_t(i)))
|
||||
if err = c.error(rc); err == nil {
|
||||
current = util.Read64[int64](c.mod, curPtr)
|
||||
highwater = util.Read64[int64](c.mod, hiPtr)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -465,47 +444,60 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro
|
||||
// https://sqlite.org/c3ref/table_column_metadata.html
|
||||
func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) {
|
||||
defer c.arena.mark()()
|
||||
|
||||
var schemaPtr, columnPtr uint32
|
||||
declTypePtr := c.arena.new(ptrlen)
|
||||
collSeqPtr := c.arena.new(ptrlen)
|
||||
notNullPtr := c.arena.new(ptrlen)
|
||||
autoIncPtr := c.arena.new(ptrlen)
|
||||
primaryKeyPtr := c.arena.new(ptrlen)
|
||||
var (
|
||||
declTypePtr ptr_t
|
||||
collSeqPtr ptr_t
|
||||
notNullPtr ptr_t
|
||||
primaryKeyPtr ptr_t
|
||||
autoIncPtr ptr_t
|
||||
columnPtr ptr_t
|
||||
schemaPtr ptr_t
|
||||
)
|
||||
if column != "" {
|
||||
declTypePtr = c.arena.new(ptrlen)
|
||||
collSeqPtr = c.arena.new(ptrlen)
|
||||
notNullPtr = c.arena.new(ptrlen)
|
||||
primaryKeyPtr = c.arena.new(ptrlen)
|
||||
autoIncPtr = c.arena.new(ptrlen)
|
||||
columnPtr = c.arena.string(column)
|
||||
}
|
||||
if schema != "" {
|
||||
schemaPtr = c.arena.string(schema)
|
||||
}
|
||||
tablePtr := c.arena.string(table)
|
||||
if column != "" {
|
||||
columnPtr = c.arena.string(column)
|
||||
}
|
||||
|
||||
r := c.call("sqlite3_table_column_metadata", uint64(c.handle),
|
||||
uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr),
|
||||
uint64(declTypePtr), uint64(collSeqPtr),
|
||||
uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr))
|
||||
if err = c.error(r); err == nil && column != "" {
|
||||
if ptr := util.ReadUint32(c.mod, declTypePtr); ptr != 0 {
|
||||
rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle),
|
||||
stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr),
|
||||
stk_t(declTypePtr), stk_t(collSeqPtr),
|
||||
stk_t(notNullPtr), stk_t(primaryKeyPtr), stk_t(autoIncPtr)))
|
||||
if err = c.error(rc); err == nil && column != "" {
|
||||
if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 {
|
||||
declType = util.ReadString(c.mod, ptr, _MAX_NAME)
|
||||
}
|
||||
if ptr := util.ReadUint32(c.mod, collSeqPtr); ptr != 0 {
|
||||
if ptr := util.Read32[ptr_t](c.mod, collSeqPtr); ptr != 0 {
|
||||
collSeq = util.ReadString(c.mod, ptr, _MAX_NAME)
|
||||
}
|
||||
notNull = util.ReadUint32(c.mod, notNullPtr) != 0
|
||||
autoInc = util.ReadUint32(c.mod, autoIncPtr) != 0
|
||||
primaryKey = util.ReadUint32(c.mod, primaryKeyPtr) != 0
|
||||
notNull = util.ReadBool(c.mod, notNullPtr)
|
||||
autoInc = util.ReadBool(c.mod, autoIncPtr)
|
||||
primaryKey = util.ReadBool(c.mod, primaryKeyPtr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
func (c *Conn) error(rc res_t, sql ...string) error {
|
||||
return c.sqlite.error(rc, c.handle, sql...)
|
||||
}
|
||||
|
||||
func (c *Conn) stmtsIter(yield func(*Stmt) bool) {
|
||||
for _, s := range c.stmts {
|
||||
if !yield(s) {
|
||||
break
|
||||
// Stmts returns an iterator for the prepared statements
|
||||
// associated with the database connection.
|
||||
//
|
||||
// https://sqlite.org/c3ref/next_stmt.html
|
||||
func (c *Conn) Stmts() iter.Seq[*Stmt] {
|
||||
return func(yield func(*Stmt) bool) {
|
||||
for _, s := range c.stmts {
|
||||
if !yield(s) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
11
conn_iter.go
11
conn_iter.go
@@ -1,11 +0,0 @@
|
||||
//go:build go1.23
|
||||
|
||||
package sqlite3
|
||||
|
||||
import "iter"
|
||||
|
||||
// Stmts returns an iterator for the prepared statements
|
||||
// associated with the database connection.
|
||||
//
|
||||
// https://sqlite.org/c3ref/next_stmt.html
|
||||
func (c *Conn) Stmts() iter.Seq[*Stmt] { return c.stmtsIter }
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !go1.23
|
||||
|
||||
package sqlite3
|
||||
|
||||
// Stmts returns an iterator for the prepared statements
|
||||
// associated with the database connection.
|
||||
//
|
||||
// https://sqlite.org/c3ref/next_stmt.html
|
||||
func (c *Conn) Stmts() func(func(*Stmt) bool) { return c.stmtsIter }
|
||||
60
const.go
60
const.go
@@ -1,19 +1,28 @@
|
||||
package sqlite3
|
||||
|
||||
import "strconv"
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
_OK = 0 /* Successful result */
|
||||
_ROW = 100 /* sqlite3_step() has another row ready */
|
||||
_DONE = 101 /* sqlite3_step() has finished executing */
|
||||
|
||||
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
|
||||
_MAX_LENGTH = 1e9
|
||||
_MAX_SQL_LENGTH = 1e9
|
||||
_MAX_FUNCTION_ARG = 100
|
||||
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
|
||||
_MAX_LENGTH = 1e9
|
||||
_MAX_SQL_LENGTH = 1e9
|
||||
|
||||
ptrlen = 4
|
||||
intlen = 4
|
||||
ptrlen = util.PtrLen
|
||||
intlen = util.IntLen
|
||||
)
|
||||
|
||||
type (
|
||||
stk_t = util.Stk_t
|
||||
ptr_t = util.Ptr_t
|
||||
res_t = util.Res_t
|
||||
)
|
||||
|
||||
// ErrorCode is a result code that [Error.Code] might return.
|
||||
@@ -64,6 +73,9 @@ const (
|
||||
ERROR_MISSING_COLLSEQ ExtendedErrorCode = xErrorCode(ERROR) | (1 << 8)
|
||||
ERROR_RETRY ExtendedErrorCode = xErrorCode(ERROR) | (2 << 8)
|
||||
ERROR_SNAPSHOT ExtendedErrorCode = xErrorCode(ERROR) | (3 << 8)
|
||||
ERROR_RESERVESIZE ExtendedErrorCode = xErrorCode(ERROR) | (4 << 8)
|
||||
ERROR_KEY ExtendedErrorCode = xErrorCode(ERROR) | (5 << 8)
|
||||
ERROR_UNABLE ExtendedErrorCode = xErrorCode(ERROR) | (6 << 8)
|
||||
IOERR_READ ExtendedErrorCode = xErrorCode(IOERR) | (1 << 8)
|
||||
IOERR_SHORT_READ ExtendedErrorCode = xErrorCode(IOERR) | (2 << 8)
|
||||
IOERR_WRITE ExtendedErrorCode = xErrorCode(IOERR) | (3 << 8)
|
||||
@@ -98,6 +110,8 @@ const (
|
||||
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
|
||||
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
|
||||
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
|
||||
IOERR_BADKEY ExtendedErrorCode = xErrorCode(IOERR) | (35 << 8)
|
||||
IOERR_CODEC ExtendedErrorCode = xErrorCode(IOERR) | (36 << 8)
|
||||
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
|
||||
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
|
||||
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)
|
||||
@@ -166,6 +180,7 @@ const (
|
||||
PREPARE_PERSISTENT PrepareFlag = 0x01
|
||||
PREPARE_NORMALIZE PrepareFlag = 0x02
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
PREPARE_DONT_LOG PrepareFlag = 0x10
|
||||
)
|
||||
|
||||
// FunctionFlag is a flag that can be passed to
|
||||
@@ -175,12 +190,12 @@ const (
|
||||
type FunctionFlag uint32
|
||||
|
||||
const (
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
SELFORDER1 FunctionFlag = 0x002000000
|
||||
// SUBTYPE FunctionFlag = 0x000100000
|
||||
// RESULT_SUBTYPE FunctionFlag = 0x001000000
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
SUBTYPE FunctionFlag = 0x000100000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
RESULT_SUBTYPE FunctionFlag = 0x001000000
|
||||
SELFORDER1 FunctionFlag = 0x002000000
|
||||
)
|
||||
|
||||
// StmtStatus name counter values associated with the [Stmt.Status] method.
|
||||
@@ -219,6 +234,8 @@ const (
|
||||
DBSTATUS_DEFERRED_FKS DBStatus = 10
|
||||
DBSTATUS_CACHE_USED_SHARED DBStatus = 11
|
||||
DBSTATUS_CACHE_SPILL DBStatus = 12
|
||||
DBSTATUS_TEMPBUF_SPILL DBStatus = 13
|
||||
// DBSTATUS_MAX DBStatus = 13
|
||||
)
|
||||
|
||||
// DBConfig are the available database connection configuration options.
|
||||
@@ -247,7 +264,10 @@ const (
|
||||
DBCONFIG_TRUSTED_SCHEMA DBConfig = 1017
|
||||
DBCONFIG_STMT_SCANSTATUS DBConfig = 1018
|
||||
DBCONFIG_REVERSE_SCANORDER DBConfig = 1019
|
||||
// DBCONFIG_MAX DBConfig = 1019
|
||||
DBCONFIG_ENABLE_ATTACH_CREATE DBConfig = 1020
|
||||
DBCONFIG_ENABLE_ATTACH_WRITE DBConfig = 1021
|
||||
DBCONFIG_ENABLE_COMMENTS DBConfig = 1022
|
||||
// DBCONFIG_MAX DBConfig = 1022
|
||||
)
|
||||
|
||||
// FcntlOpcode are the available opcodes for [Conn.FileControl].
|
||||
@@ -266,6 +286,7 @@ const (
|
||||
FCNTL_DATA_VERSION FcntlOpcode = 35
|
||||
FCNTL_RESERVE_BYTES FcntlOpcode = 38
|
||||
FCNTL_RESET_CACHE FcntlOpcode = 42
|
||||
FCNTL_NULL_IO FcntlOpcode = 43
|
||||
)
|
||||
|
||||
// LimitCategory are the available run-time limit categories.
|
||||
@@ -347,13 +368,14 @@ const (
|
||||
// CheckpointMode are all the checkpoint mode values.
|
||||
//
|
||||
// https://sqlite.org/c3ref/c_checkpoint_full.html
|
||||
type CheckpointMode uint32
|
||||
type CheckpointMode int32
|
||||
|
||||
const (
|
||||
CHECKPOINT_PASSIVE CheckpointMode = 0 /* Do as much as possible w/o blocking */
|
||||
CHECKPOINT_FULL CheckpointMode = 1 /* Wait for writers, then checkpoint */
|
||||
CHECKPOINT_RESTART CheckpointMode = 2 /* Like FULL but wait for readers */
|
||||
CHECKPOINT_TRUNCATE CheckpointMode = 3 /* Like RESTART but also truncate WAL */
|
||||
CHECKPOINT_NOOP CheckpointMode = -1 /* Do no work at all */
|
||||
CHECKPOINT_PASSIVE CheckpointMode = 0 /* Do as much as possible w/o blocking */
|
||||
CHECKPOINT_FULL CheckpointMode = 1 /* Wait for writers, then checkpoint */
|
||||
CHECKPOINT_RESTART CheckpointMode = 2 /* Like FULL but wait for readers */
|
||||
CHECKPOINT_TRUNCATE CheckpointMode = 3 /* Like RESTART but also truncate WAL */
|
||||
)
|
||||
|
||||
// TxnState are the allowed return values from [Conn.TxnState].
|
||||
|
||||
75
context.go
75
context.go
@@ -1,7 +1,6 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
@@ -15,7 +14,7 @@ import (
|
||||
// https://sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
c *Conn
|
||||
handle uint32
|
||||
handle ptr_t
|
||||
}
|
||||
|
||||
// Conn returns the database connection of the
|
||||
@@ -32,14 +31,14 @@ func (ctx Context) Conn() *Conn {
|
||||
// https://sqlite.org/c3ref/get_auxdata.html
|
||||
func (ctx Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(ctx.c.ctx, data)
|
||||
ctx.c.call("sqlite3_set_auxdata_go", uint64(ctx.handle), uint64(n), uint64(ptr))
|
||||
ctx.c.call("sqlite3_set_auxdata_go", stk_t(ctx.handle), stk_t(n), stk_t(ptr))
|
||||
}
|
||||
|
||||
// GetAuxData returns metadata for argument n of the function.
|
||||
//
|
||||
// https://sqlite.org/c3ref/get_auxdata.html
|
||||
func (ctx Context) GetAuxData(n int) any {
|
||||
ptr := uint32(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n)))
|
||||
ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", stk_t(ctx.handle), stk_t(n)))
|
||||
return util.GetHandle(ctx.c.ctx, ptr)
|
||||
}
|
||||
|
||||
@@ -68,7 +67,7 @@ func (ctx Context) ResultInt(value int) {
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultInt64(value int64) {
|
||||
ctx.c.call("sqlite3_result_int64",
|
||||
uint64(ctx.handle), uint64(value))
|
||||
stk_t(ctx.handle), stk_t(value))
|
||||
}
|
||||
|
||||
// ResultFloat sets the result of the function to a float64.
|
||||
@@ -76,7 +75,7 @@ func (ctx Context) ResultInt64(value int64) {
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultFloat(value float64) {
|
||||
ctx.c.call("sqlite3_result_double",
|
||||
uint64(ctx.handle), math.Float64bits(value))
|
||||
stk_t(ctx.handle), stk_t(math.Float64bits(value)))
|
||||
}
|
||||
|
||||
// ResultText sets the result of the function to a string.
|
||||
@@ -85,27 +84,33 @@ func (ctx Context) ResultFloat(value float64) {
|
||||
func (ctx Context) ResultText(value string) {
|
||||
ptr := ctx.c.newString(value)
|
||||
ctx.c.call("sqlite3_result_text_go",
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
|
||||
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
|
||||
}
|
||||
|
||||
// ResultRawText sets the text result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultRawText(value []byte) {
|
||||
if len(value) == 0 {
|
||||
ctx.ResultText("")
|
||||
return
|
||||
}
|
||||
ptr := ctx.c.newBytes(value)
|
||||
ctx.c.call("sqlite3_result_text_go",
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
|
||||
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultBlob(value []byte) {
|
||||
if len(value) == 0 {
|
||||
ctx.ResultZeroBlob(0)
|
||||
return
|
||||
}
|
||||
ptr := ctx.c.newBytes(value)
|
||||
ctx.c.call("sqlite3_result_blob_go",
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
|
||||
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
|
||||
}
|
||||
|
||||
// ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
@@ -113,7 +118,7 @@ func (ctx Context) ResultBlob(value []byte) {
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultZeroBlob(n int64) {
|
||||
ctx.c.call("sqlite3_result_zeroblob64",
|
||||
uint64(ctx.handle), uint64(n))
|
||||
stk_t(ctx.handle), stk_t(n))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of the function to NULL.
|
||||
@@ -121,7 +126,7 @@ func (ctx Context) ResultZeroBlob(n int64) {
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultNull() {
|
||||
ctx.c.call("sqlite3_result_null",
|
||||
uint64(ctx.handle))
|
||||
stk_t(ctx.handle))
|
||||
}
|
||||
|
||||
// ResultTime sets the result of the function to a [time.Time].
|
||||
@@ -146,14 +151,14 @@ func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
}
|
||||
|
||||
func (ctx Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano)) + 5
|
||||
const maxlen = int64(len(time.RFC3339Nano)) + 5
|
||||
|
||||
ptr := ctx.c.new(maxlen)
|
||||
buf := util.View(ctx.c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
ctx.c.call("sqlite3_result_text_go",
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(buf)))
|
||||
stk_t(ctx.handle), stk_t(ptr), stk_t(len(buf)))
|
||||
}
|
||||
|
||||
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
|
||||
@@ -164,19 +169,7 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
|
||||
func (ctx Context) ResultPointer(ptr any) {
|
||||
valPtr := util.AddHandle(ctx.c.ctx, ptr)
|
||||
ctx.c.call("sqlite3_result_pointer_go",
|
||||
uint64(ctx.handle), uint64(valPtr))
|
||||
}
|
||||
|
||||
// ResultJSON sets the result of the function to the JSON encoding of value.
|
||||
//
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultJSON(value any) {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultRawText(data)
|
||||
stk_t(ctx.handle), stk_t(valPtr))
|
||||
}
|
||||
|
||||
// ResultValue sets the result of the function to a copy of [Value].
|
||||
@@ -188,7 +181,7 @@ func (ctx Context) ResultValue(value Value) {
|
||||
return
|
||||
}
|
||||
ctx.c.call("sqlite3_result_value",
|
||||
uint64(ctx.handle), uint64(value.handle))
|
||||
stk_t(ctx.handle), stk_t(value.handle))
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
@@ -196,33 +189,41 @@ func (ctx Context) ResultValue(value Value) {
|
||||
// https://sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
ctx.c.call("sqlite3_result_error_nomem", uint64(ctx.handle))
|
||||
ctx.c.call("sqlite3_result_error_nomem", stk_t(ctx.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
ctx.c.call("sqlite3_result_error_toobig", uint64(ctx.handle))
|
||||
ctx.c.call("sqlite3_result_error_toobig", stk_t(ctx.handle))
|
||||
return
|
||||
}
|
||||
|
||||
msg, code := errorCode(err, _OK)
|
||||
msg, code := errorCode(err, ERROR)
|
||||
if msg != "" {
|
||||
defer ctx.c.arena.mark()()
|
||||
ptr := ctx.c.arena.string(msg)
|
||||
ctx.c.call("sqlite3_result_error",
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(msg)))
|
||||
stk_t(ctx.handle), stk_t(ptr), stk_t(len(msg)))
|
||||
}
|
||||
if code != _OK {
|
||||
if code != res_t(ERROR) {
|
||||
ctx.c.call("sqlite3_result_error_code",
|
||||
uint64(ctx.handle), uint64(code))
|
||||
stk_t(ctx.handle), stk_t(code))
|
||||
}
|
||||
}
|
||||
|
||||
// ResultSubtype sets the subtype of the result of the function.
|
||||
//
|
||||
// https://sqlite.org/c3ref/result_subtype.html
|
||||
func (ctx Context) ResultSubtype(t uint) {
|
||||
ctx.c.call("sqlite3_result_subtype",
|
||||
stk_t(ctx.handle), stk_t(uint32(t)))
|
||||
}
|
||||
|
||||
// VTabNoChange may return true if a column is being fetched as part
|
||||
// of an update during which the column value will not change.
|
||||
//
|
||||
// https://sqlite.org/c3ref/vtab_nochange.html
|
||||
func (ctx Context) VTabNoChange() bool {
|
||||
r := ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle))
|
||||
return r != 0
|
||||
b := int32(ctx.c.call("sqlite3_vtab_nochange", stk_t(ctx.handle)))
|
||||
return b != 0
|
||||
}
|
||||
|
||||
337
driver/driver.go
337
driver/driver.go
@@ -20,22 +20,45 @@
|
||||
// - a [serializable] transaction is always "immediate";
|
||||
// - a [read-only] transaction is always "deferred".
|
||||
//
|
||||
// # Datatypes In SQLite
|
||||
//
|
||||
// SQLite is dynamically typed.
|
||||
// Columns can mostly hold any value regardless of their declared type.
|
||||
// SQLite supports most [driver.Value] types out of the box,
|
||||
// but bool and [time.Time] require special care.
|
||||
//
|
||||
// Booleans can be stored on any column type and scanned back to a *bool.
|
||||
// However, if scanned to a *any, booleans may either become an
|
||||
// int64, string or bool, depending on the declared type of the column.
|
||||
// If you use BOOLEAN for your column type,
|
||||
// 1 and 0 will always scan as true and false.
|
||||
//
|
||||
// # Working with time
|
||||
//
|
||||
// Time values can similarly be stored on any column type.
|
||||
// The time encoding/decoding format can be specified using "_timefmt":
|
||||
//
|
||||
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
|
||||
//
|
||||
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
|
||||
// Special values are: "auto" (the default), "sqlite", "rfc3339";
|
||||
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
|
||||
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
|
||||
// - "rfc3339" encodes and decodes RFC 3339 only.
|
||||
//
|
||||
// If you encode as RFC 3339 (the default),
|
||||
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
|
||||
// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
|
||||
//
|
||||
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
|
||||
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
|
||||
// If you encode as RFC 3339 (the default),
|
||||
// consider using the TIME [collating sequence] to produce time-ordered sequences.
|
||||
//
|
||||
// If you encode as RFC 3339 (the default),
|
||||
// time values will scan back to a *time.Time unless your column type is TEXT.
|
||||
// Otherwise, if scanned to a *any, time values may either become an
|
||||
// int64, float64 or string, depending on the time format and declared type of the column.
|
||||
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
|
||||
// "_timefmt" will be used to decode values.
|
||||
//
|
||||
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
|
||||
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
|
||||
//
|
||||
// When using a custom time struct, you'll have to implement
|
||||
// [database/sql/driver.Valuer] and [database/sql.Scanner].
|
||||
@@ -48,7 +71,7 @@
|
||||
// The Scan method needs to take into account that the value it receives can be of differing types.
|
||||
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
|
||||
// Or it can be a: string, int64, float64, []byte, or nil,
|
||||
// depending on the column type and what whoever wrote the value.
|
||||
// depending on the column type and whoever wrote the value.
|
||||
// [sqlite3.TimeFormat.Decode] may help.
|
||||
//
|
||||
// # Setting PRAGMAs
|
||||
@@ -201,7 +224,7 @@ func (n *connector) Driver() driver.Driver {
|
||||
return &SQLite{}
|
||||
}
|
||||
|
||||
func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
|
||||
func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
|
||||
c := &conn{
|
||||
txLock: n.txLock,
|
||||
tmRead: n.tmRead,
|
||||
@@ -213,13 +236,14 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if res == nil {
|
||||
if ret == nil {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
if old := c.Conn.SetInterrupt(ctx); old != ctx {
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
}
|
||||
|
||||
if !n.pragmas {
|
||||
err = c.Conn.BusyTimeout(time.Minute)
|
||||
@@ -239,10 +263,8 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
|
||||
return nil, err
|
||||
}
|
||||
defer s.Close()
|
||||
if s.Step() && s.ColumnBool(0) {
|
||||
c.readOnly = '1'
|
||||
} else {
|
||||
c.readOnly = '0'
|
||||
if s.Step() {
|
||||
c.readOnly = s.ColumnBool(0)
|
||||
}
|
||||
err = s.Close()
|
||||
if err != nil {
|
||||
@@ -274,6 +296,7 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer conn.Close()
|
||||
//
|
||||
// err = conn.Raw(func(driverConn any) error {
|
||||
// conn := driverConn.(driver.Conn)
|
||||
@@ -297,7 +320,7 @@ type conn struct {
|
||||
txReset string
|
||||
tmRead sqlite3.TimeFormat
|
||||
tmWrite sqlite3.TimeFormat
|
||||
readOnly byte
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -333,13 +356,14 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
|
||||
c.txReset = ``
|
||||
txBegin := `BEGIN ` + txLock
|
||||
if opts.ReadOnly {
|
||||
if opts.ReadOnly && !c.readOnly {
|
||||
txBegin += ` ; PRAGMA query_only=on`
|
||||
c.txReset = `; PRAGMA query_only=` + string(c.readOnly)
|
||||
c.txReset = `; PRAGMA query_only=off`
|
||||
}
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
if old := c.Conn.SetInterrupt(ctx); old != ctx {
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
}
|
||||
|
||||
err := c.Conn.Exec(txBegin)
|
||||
if err != nil {
|
||||
@@ -357,13 +381,12 @@ func (c *conn) Commit() error {
|
||||
}
|
||||
|
||||
func (c *conn) Rollback() error {
|
||||
err := c.Conn.Exec(`ROLLBACK` + c.txReset)
|
||||
if errors.Is(err, sqlite3.INTERRUPT) {
|
||||
old := c.Conn.SetInterrupt(context.Background())
|
||||
// ROLLBACK even if interrupted.
|
||||
ctx := context.Background()
|
||||
if old := c.Conn.SetInterrupt(ctx); old != ctx {
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
err = c.Conn.Exec(`ROLLBACK` + c.txReset)
|
||||
}
|
||||
return err
|
||||
return c.Conn.Exec(`ROLLBACK` + c.txReset)
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
@@ -372,8 +395,9 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
}
|
||||
|
||||
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
if old := c.Conn.SetInterrupt(ctx); old != ctx {
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
}
|
||||
|
||||
s, tail, err := c.Conn.Prepare(query)
|
||||
if err != nil {
|
||||
@@ -398,8 +422,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
|
||||
return resultRowsAffected(0), nil
|
||||
}
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
if old := c.Conn.SetInterrupt(ctx); old != ctx {
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
}
|
||||
|
||||
err := c.Conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -462,16 +487,19 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
return nil, err
|
||||
}
|
||||
|
||||
old := s.Stmt.Conn().SetInterrupt(ctx)
|
||||
defer s.Stmt.Conn().SetInterrupt(old)
|
||||
c := s.Stmt.Conn()
|
||||
if old := c.SetInterrupt(ctx); old != ctx {
|
||||
defer c.SetInterrupt(old)
|
||||
}
|
||||
|
||||
err = s.Stmt.Exec()
|
||||
s.Stmt.ClearBindings()
|
||||
err = errors.Join(
|
||||
s.Stmt.Exec(),
|
||||
s.Stmt.ClearBindings())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newResult(s.Stmt.Conn()), nil
|
||||
return newResult(c), nil
|
||||
}
|
||||
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
@@ -516,8 +544,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
|
||||
err = s.Stmt.BindTime(id, a, s.tmWrite)
|
||||
case util.JSON:
|
||||
err = s.Stmt.BindJSON(id, a.Value)
|
||||
case util.PointerUnwrap:
|
||||
err = s.Stmt.BindPointer(id, util.UnwrapPointer(a))
|
||||
case util.Pointer:
|
||||
err = s.Stmt.BindPointer(id, a.Value)
|
||||
case nil:
|
||||
err = s.Stmt.BindNull(id)
|
||||
default:
|
||||
@@ -535,7 +563,7 @@ func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
time.Time, sqlite3.ZeroBlob,
|
||||
util.JSON, util.PointerUnwrap,
|
||||
util.JSON, util.Pointer,
|
||||
nil:
|
||||
return nil
|
||||
default:
|
||||
@@ -574,28 +602,59 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
|
||||
return int64(r), nil
|
||||
}
|
||||
|
||||
type scantype byte
|
||||
|
||||
const (
|
||||
_ANY scantype = iota
|
||||
_INT
|
||||
_REAL
|
||||
_TEXT
|
||||
_BLOB
|
||||
_NULL
|
||||
_BOOL
|
||||
_TIME
|
||||
_NOT_NULL
|
||||
)
|
||||
|
||||
var (
|
||||
_ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{}
|
||||
_ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{}
|
||||
_ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{}
|
||||
_ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{}
|
||||
_ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{}
|
||||
_ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{}
|
||||
)
|
||||
|
||||
func scanFromDecl(decl string) scantype {
|
||||
// These types are only used before we have rows,
|
||||
// and otherwise as type hints.
|
||||
// The first few ensure STRICT tables are strictly typed.
|
||||
// The other two are type hints for booleans and time.
|
||||
switch decl {
|
||||
case "INT", "INTEGER":
|
||||
return _INT
|
||||
case "REAL":
|
||||
return _REAL
|
||||
case "TEXT":
|
||||
return _TEXT
|
||||
case "BLOB":
|
||||
return _BLOB
|
||||
case "BOOLEAN":
|
||||
return _BOOL
|
||||
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
|
||||
return _TIME
|
||||
}
|
||||
return _ANY
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
ctx context.Context
|
||||
*stmt
|
||||
names []string
|
||||
types []string
|
||||
nulls []bool
|
||||
scans []scantype
|
||||
}
|
||||
|
||||
type scantype byte
|
||||
|
||||
const (
|
||||
_ANY scantype = iota
|
||||
_INT scantype = scantype(sqlite3.INTEGER)
|
||||
_REAL scantype = scantype(sqlite3.FLOAT)
|
||||
_TEXT scantype = scantype(sqlite3.TEXT)
|
||||
_BLOB scantype = scantype(sqlite3.BLOB)
|
||||
_NULL scantype = scantype(sqlite3.NULL)
|
||||
_BOOL scantype = iota
|
||||
_TIME
|
||||
)
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
|
||||
@@ -603,8 +662,9 @@ var (
|
||||
)
|
||||
|
||||
func (r *rows) Close() error {
|
||||
r.Stmt.ClearBindings()
|
||||
return r.Stmt.Reset()
|
||||
return errors.Join(
|
||||
r.Stmt.Reset(),
|
||||
r.Stmt.ClearBindings())
|
||||
}
|
||||
|
||||
func (r *rows) Columns() []string {
|
||||
@@ -619,79 +679,69 @@ func (r *rows) Columns() []string {
|
||||
return r.names
|
||||
}
|
||||
|
||||
func (r *rows) scanType(index int) scantype {
|
||||
if r.scans == nil {
|
||||
count := len(r.names)
|
||||
scans := make([]scantype, count)
|
||||
for i := range scans {
|
||||
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
|
||||
}
|
||||
r.scans = scans
|
||||
}
|
||||
return r.scans[index] &^ _NOT_NULL
|
||||
}
|
||||
|
||||
func (r *rows) loadColumnMetadata() {
|
||||
if r.nulls == nil {
|
||||
count := r.Stmt.ColumnCount()
|
||||
nulls := make([]bool, count)
|
||||
if r.types == nil {
|
||||
c := r.Stmt.Conn()
|
||||
count := len(r.names)
|
||||
types := make([]string, count)
|
||||
scans := make([]scantype, count)
|
||||
for i := range nulls {
|
||||
if col := r.Stmt.ColumnOriginName(i); col != "" {
|
||||
types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
|
||||
for i := range types {
|
||||
var declType string
|
||||
var notNull, autoInc bool
|
||||
if column := r.Stmt.ColumnOriginName(i); column != "" {
|
||||
declType, _, notNull, _, autoInc, _ = c.TableColumnMetadata(
|
||||
r.Stmt.ColumnDatabaseName(i),
|
||||
r.Stmt.ColumnTableName(i),
|
||||
col)
|
||||
types[i] = strings.ToUpper(types[i])
|
||||
// These types are only used before we have rows,
|
||||
// and otherwise as type hints.
|
||||
// The first few ensure STRICT tables are strictly typed.
|
||||
// The other two are type hints for booleans and time.
|
||||
switch types[i] {
|
||||
case "INT", "INTEGER":
|
||||
scans[i] = _INT
|
||||
case "REAL":
|
||||
scans[i] = _REAL
|
||||
case "TEXT":
|
||||
scans[i] = _TEXT
|
||||
case "BLOB":
|
||||
scans[i] = _BLOB
|
||||
case "BOOLEAN":
|
||||
scans[i] = _BOOL
|
||||
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
|
||||
scans[i] = _TIME
|
||||
}
|
||||
column)
|
||||
} else {
|
||||
declType = r.Stmt.ColumnDeclType(i)
|
||||
}
|
||||
if declType != "" {
|
||||
declType = strings.ToUpper(declType)
|
||||
scans[i] = scanFromDecl(declType)
|
||||
types[i] = declType
|
||||
}
|
||||
if notNull || autoInc {
|
||||
scans[i] |= _NOT_NULL
|
||||
}
|
||||
}
|
||||
r.nulls = nulls
|
||||
r.types = types
|
||||
r.scans = scans
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) declType(index int) string {
|
||||
if r.types == nil {
|
||||
count := r.Stmt.ColumnCount()
|
||||
types := make([]string, count)
|
||||
for i := range types {
|
||||
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
|
||||
}
|
||||
r.types = types
|
||||
}
|
||||
return r.types[index]
|
||||
}
|
||||
|
||||
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
|
||||
r.loadColumnMetadata()
|
||||
decltype := r.types[index]
|
||||
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
|
||||
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
|
||||
decltype = decltype[:i]
|
||||
decl := r.types[index]
|
||||
if len := len(decl); len > 0 && decl[len-1] == ')' {
|
||||
if i := strings.LastIndexByte(decl, '('); i >= 0 {
|
||||
decl = decl[:i]
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(decltype)
|
||||
return strings.TrimSpace(decl)
|
||||
}
|
||||
|
||||
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
||||
r.loadColumnMetadata()
|
||||
if r.nulls[index] {
|
||||
return false, true
|
||||
}
|
||||
return true, false
|
||||
nullable = r.scans[index]&^_NOT_NULL == 0
|
||||
return nullable, !nullable
|
||||
}
|
||||
|
||||
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
|
||||
r.loadColumnMetadata()
|
||||
scan := r.scans[index]
|
||||
scan := r.scans[index] &^ _NOT_NULL
|
||||
|
||||
if r.Stmt.Busy() {
|
||||
// SQLite is dynamically typed and we now have a row.
|
||||
@@ -703,7 +753,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
|
||||
switch {
|
||||
case scan == _TIME && val != _BLOB && val != _NULL:
|
||||
t := r.Stmt.ColumnTime(index, r.tmRead)
|
||||
useValType = t == time.Time{}
|
||||
useValType = t.IsZero()
|
||||
case scan == _BOOL && val == _INT:
|
||||
i := r.Stmt.ColumnInt64(index)
|
||||
useValType = i != 0 && i != 1
|
||||
@@ -717,25 +767,27 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
|
||||
|
||||
switch scan {
|
||||
case _INT:
|
||||
return reflect.TypeOf(int64(0))
|
||||
return reflect.TypeFor[int64]()
|
||||
case _REAL:
|
||||
return reflect.TypeOf(float64(0))
|
||||
return reflect.TypeFor[float64]()
|
||||
case _TEXT:
|
||||
return reflect.TypeOf("")
|
||||
return reflect.TypeFor[string]()
|
||||
case _BLOB:
|
||||
return reflect.TypeOf([]byte{})
|
||||
return reflect.TypeFor[[]byte]()
|
||||
case _BOOL:
|
||||
return reflect.TypeOf(false)
|
||||
return reflect.TypeFor[bool]()
|
||||
case _TIME:
|
||||
return reflect.TypeOf(time.Time{})
|
||||
return reflect.TypeFor[time.Time]()
|
||||
default:
|
||||
return reflect.TypeOf((*any)(nil)).Elem()
|
||||
return reflect.TypeFor[any]()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
old := r.Stmt.Conn().SetInterrupt(r.ctx)
|
||||
defer r.Stmt.Conn().SetInterrupt(old)
|
||||
c := r.Stmt.Conn()
|
||||
if old := c.SetInterrupt(r.ctx); old != r.ctx {
|
||||
defer c.SetInterrupt(old)
|
||||
}
|
||||
|
||||
if !r.Stmt.Step() {
|
||||
if err := r.Stmt.Err(); err != nil {
|
||||
@@ -745,36 +797,41 @@ func (r *rows) Next(dest []driver.Value) error {
|
||||
}
|
||||
|
||||
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
|
||||
err := r.Stmt.Columns(data...)
|
||||
if err := r.Stmt.ColumnsRaw(data...); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range dest {
|
||||
if t, ok := r.decodeTime(i, dest[i]); ok {
|
||||
dest[i] = t
|
||||
scan := r.scanType(i)
|
||||
if v, ok := dest[i].([]byte); ok {
|
||||
if len(v) == cap(v) { // a BLOB
|
||||
continue
|
||||
}
|
||||
if scan != _TEXT {
|
||||
switch r.tmWrite {
|
||||
case "", time.RFC3339, time.RFC3339Nano:
|
||||
t, ok := maybeTime(v)
|
||||
if ok {
|
||||
dest[i] = t
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
dest[i] = string(v)
|
||||
}
|
||||
switch scan {
|
||||
case _TIME:
|
||||
t, err := r.tmRead.Decode(dest[i])
|
||||
if err == nil {
|
||||
dest[i] = t
|
||||
}
|
||||
case _BOOL:
|
||||
switch dest[i] {
|
||||
case int64(0):
|
||||
dest[i] = false
|
||||
case int64(1):
|
||||
dest[i] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
|
||||
switch v := v.(type) {
|
||||
case int64, float64:
|
||||
// could be a time value
|
||||
case string:
|
||||
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
|
||||
break
|
||||
}
|
||||
t, ok := maybeTime(v)
|
||||
if ok {
|
||||
return t, true
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
switch r.declType(i) {
|
||||
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
|
||||
// could be a time value
|
||||
default:
|
||||
return
|
||||
}
|
||||
t, err := r.tmRead.Decode(v)
|
||||
return t, err == nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func Test_Open_error(t *testing.T) {
|
||||
func Test_Open_dir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ".")
|
||||
db, err := Open(".")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,18 +43,18 @@ func Test_Open_dir(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
if !errors.Is(err, sqlite3.CANTOPEN_ISDIR) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN_ISDIR", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_pragma(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t, url.Values{
|
||||
dsn := memdb.TestDB(t, url.Values{
|
||||
"_pragma": {"busy_timeout(1000)"},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -72,11 +72,11 @@ func Test_Open_pragma(t *testing.T) {
|
||||
|
||||
func Test_Open_pragma_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t, url.Values{
|
||||
dsn := memdb.TestDB(t, url.Values{
|
||||
"_pragma": {"busy_timeout 1000"},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -100,12 +100,12 @@ func Test_Open_pragma_invalid(t *testing.T) {
|
||||
|
||||
func Test_Open_txLock(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t, url.Values{
|
||||
dsn := memdb.TestDB(t, url.Values{
|
||||
"_txlock": {"exclusive"},
|
||||
"_pragma": {"busy_timeout(1000)"},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -136,11 +136,11 @@ func Test_Open_txLock(t *testing.T) {
|
||||
|
||||
func Test_Open_txLock_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t, url.Values{
|
||||
dsn := memdb.TestDB(t, url.Values{
|
||||
"_txlock": {"xclusive"},
|
||||
})
|
||||
|
||||
_, err := sql.Open("sqlite3", tmp+"_txlock=xclusive")
|
||||
_, err := Open(dsn)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
@@ -151,31 +151,28 @@ func Test_Open_txLock_invalid(t *testing.T) {
|
||||
|
||||
func Test_BeginTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t, url.Values{
|
||||
dsn := memdb.TestDB(t, url.Values{
|
||||
"_txlock": {"exclusive"},
|
||||
"_pragma": {"busy_timeout(0)"},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
_, err = db.BeginTx(t.Context(), &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
if err.Error() != string(util.IsolationErr) {
|
||||
t.Error("want isolationErr")
|
||||
}
|
||||
|
||||
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
|
||||
tx1, err := db.BeginTx(t.Context(), &sql.TxOptions{ReadOnly: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
|
||||
tx2, err := db.BeginTx(t.Context(), &sql.TxOptions{ReadOnly: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -199,11 +196,69 @@ func Test_BeginTx(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_nested_context(t *testing.T) {
|
||||
t.Parallel()
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
outer, err := tx.Query(`SELECT value FROM generate_series(0)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer outer.Close()
|
||||
|
||||
want := func(rows *sql.Rows, want int) {
|
||||
t.Helper()
|
||||
|
||||
var got int
|
||||
rows.Next()
|
||||
if err := rows.Scan(&got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
want(outer, 0)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer inner.Close()
|
||||
|
||||
want(inner, 0)
|
||||
cancel()
|
||||
|
||||
var terr interface{ Temporary() bool }
|
||||
if inner.Next() || !errors.Is(inner.Err(), context.Canceled) &&
|
||||
(!errors.As(inner.Err(), &terr) || !terr.Temporary()) {
|
||||
t.Fatalf("got %v, want cancellation", inner.Err())
|
||||
}
|
||||
|
||||
want(outer, 1)
|
||||
}
|
||||
|
||||
func Test_Prepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -242,24 +297,21 @@ func Test_Prepare(t *testing.T) {
|
||||
|
||||
func Test_QueryRow_named(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
conn, err := db.Conn(t.Context())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
stmt, err := conn.PrepareContext(t.Context(), `SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -295,9 +347,9 @@ func Test_QueryRow_named(t *testing.T) {
|
||||
|
||||
func Test_QueryRow_blob_null(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -332,11 +384,11 @@ func Test_time(t *testing.T) {
|
||||
|
||||
for _, fmt := range []string{"auto", "sqlite", "rfc3339", time.ANSIC} {
|
||||
t.Run(fmt, func(t *testing.T) {
|
||||
tmp := memdb.TestDB(t, url.Values{
|
||||
dsn := memdb.TestDB(t, url.Values{
|
||||
"_timefmt": {fmt},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -369,19 +421,19 @@ func Test_time(t *testing.T) {
|
||||
|
||||
func Test_ColumnType_ScanType(t *testing.T) {
|
||||
var (
|
||||
INT = reflect.TypeOf(int64(0))
|
||||
REAL = reflect.TypeOf(float64(0))
|
||||
TEXT = reflect.TypeOf("")
|
||||
BLOB = reflect.TypeOf([]byte{})
|
||||
BOOL = reflect.TypeOf(false)
|
||||
TIME = reflect.TypeOf(time.Time{})
|
||||
ANY = reflect.TypeOf((*any)(nil)).Elem()
|
||||
INT = reflect.TypeFor[int64]()
|
||||
REAL = reflect.TypeFor[float64]()
|
||||
TEXT = reflect.TypeFor[string]()
|
||||
BLOB = reflect.TypeFor[[]byte]()
|
||||
BOOL = reflect.TypeFor[bool]()
|
||||
TIME = reflect.TypeFor[time.Time]()
|
||||
ANY = reflect.TypeFor[any]()
|
||||
)
|
||||
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := sql.Open("sqlite3", tmp)
|
||||
db, err := Open(dsn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -467,3 +519,25 @@ func Test_ColumnType_ScanType(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_loop(b *testing.B) {
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var version string
|
||||
err = db.QueryRow(`SELECT sqlite_version();`).Scan(&version)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
_, err := db.ExecContext(b.Context(),
|
||||
`WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x < 1000000) SELECT x FROM c;`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
|
||||
|
||||
package driver_test
|
||||
|
||||
// Adapted from: https://go.dev/doc/tutorial/database-access
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
@@ -27,7 +23,7 @@ func Example_customTime() {
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE data (
|
||||
id INTEGER PRIMARY KEY,
|
||||
date_time TEXT
|
||||
date_time ANY
|
||||
) STRICT;
|
||||
`)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package driver
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
|
||||
// if it roundtrips back to the same string.
|
||||
// This way times can be persisted to, and recovered from, the database,
|
||||
// but if a string is needed, [database/sql] will recover the same string.
|
||||
func maybeTime(text string) (_ time.Time, _ bool) {
|
||||
func maybeTime(text []byte) (_ time.Time, _ bool) {
|
||||
// Weed out (some) values that can't possibly be
|
||||
// [time.RFC3339Nano] timestamps.
|
||||
if len(text) < len("2006-01-02T15:04:05Z") {
|
||||
@@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) {
|
||||
|
||||
// Slow path.
|
||||
var buf [len(time.RFC3339Nano)]byte
|
||||
date, err := time.Parse(time.RFC3339Nano, text)
|
||||
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
|
||||
date, err := time.Parse(time.RFC3339Nano, string(text))
|
||||
if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
|
||||
return date, true
|
||||
}
|
||||
return
|
||||
|
||||
@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
|
||||
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
|
||||
f.Fuzz(func(t *testing.T, str string) {
|
||||
v, ok := maybeTime(str)
|
||||
v, ok := maybeTime([]byte(str))
|
||||
if ok {
|
||||
// Make sure times round-trip to the same string:
|
||||
// https://pkg.go.dev/database/sql#Rows.Scan
|
||||
@@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
|
||||
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
|
||||
|
||||
checkTime := func(t testing.TB, date time.Time) {
|
||||
v, ok := maybeTime(date.Format(time.RFC3339Nano))
|
||||
v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano))
|
||||
if ok {
|
||||
// Make sure times round-trip to the same time:
|
||||
if !v.Equal(date) {
|
||||
|
||||
@@ -12,3 +12,63 @@ func namedValues(args []driver.Value) []driver.NamedValue {
|
||||
}
|
||||
return named
|
||||
}
|
||||
|
||||
func notWhitespace(sql string) bool {
|
||||
const (
|
||||
code = iota
|
||||
slash
|
||||
minus
|
||||
ccomment
|
||||
sqlcomment
|
||||
endcomment
|
||||
)
|
||||
|
||||
state := code
|
||||
for _, b := range ([]byte)(sql) {
|
||||
if b == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
switch state {
|
||||
case code:
|
||||
switch b {
|
||||
case '/':
|
||||
state = slash
|
||||
case '-':
|
||||
state = minus
|
||||
case ' ', ';', '\t', '\n', '\v', '\f', '\r':
|
||||
continue
|
||||
default:
|
||||
return true
|
||||
}
|
||||
case slash:
|
||||
if b != '*' {
|
||||
return true
|
||||
}
|
||||
state = ccomment
|
||||
case minus:
|
||||
if b != '-' {
|
||||
return true
|
||||
}
|
||||
state = sqlcomment
|
||||
case ccomment:
|
||||
if b == '*' {
|
||||
state = endcomment
|
||||
}
|
||||
case sqlcomment:
|
||||
if b == '\n' {
|
||||
state = code
|
||||
}
|
||||
case endcomment:
|
||||
switch b {
|
||||
case '/':
|
||||
state = code
|
||||
case '*':
|
||||
state = endcomment
|
||||
default:
|
||||
state = ccomment
|
||||
}
|
||||
}
|
||||
}
|
||||
return state == slash || state == minus
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func Test_namedValues(t *testing.T) {
|
||||
@@ -12,7 +15,71 @@ func Test_namedValues(t *testing.T) {
|
||||
{Ordinal: 2, Value: false},
|
||||
}
|
||||
got := namedValues([]driver.Value{true, false})
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Fuzz_notWhitespace(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add(";")
|
||||
f.Add("0")
|
||||
f.Add("-")
|
||||
f.Add("-0")
|
||||
f.Add("--")
|
||||
f.Add("--0")
|
||||
f.Add("--\n")
|
||||
f.Add("--0\n")
|
||||
f.Add("/0")
|
||||
f.Add("/*")
|
||||
f.Add("/*/")
|
||||
f.Add("/**")
|
||||
f.Add("/*0")
|
||||
f.Add("/**/")
|
||||
f.Add("/***/")
|
||||
f.Add("/**0/")
|
||||
f.Add("\v")
|
||||
f.Add(" \v")
|
||||
f.Add("\xf0")
|
||||
f.Add("\000")
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
f.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
f.Fuzz(func(t *testing.T, str string) {
|
||||
if len(str) > 128 {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
c, err := db.Conn(t.Context())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(*conn).Conn
|
||||
stmt, tail, err := conn.Prepare(str)
|
||||
stmt.Close()
|
||||
|
||||
// It's hard to be bug for bug compatible with SQLite.
|
||||
// We settle for somewhat less:
|
||||
// - if SQLite reports whitespace, we must too
|
||||
// - if we report whitespace, SQLite must not parse a statement
|
||||
if notWhitespace(str) {
|
||||
if stmt == nil && tail == "" && err == nil {
|
||||
t.Errorf("was whitespace: %q", str)
|
||||
}
|
||||
} else {
|
||||
if stmt != nil {
|
||||
t.Errorf("was not whitespace: %q (%v)", str, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package driver
|
||||
|
||||
func notWhitespace(sql string) bool {
|
||||
const (
|
||||
code = iota
|
||||
slash
|
||||
minus
|
||||
ccomment
|
||||
sqlcomment
|
||||
endcomment
|
||||
)
|
||||
|
||||
state := code
|
||||
for _, b := range ([]byte)(sql) {
|
||||
if b == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
switch state {
|
||||
case code:
|
||||
switch b {
|
||||
case '/':
|
||||
state = slash
|
||||
case '-':
|
||||
state = minus
|
||||
case ' ', ';', '\t', '\n', '\v', '\f', '\r':
|
||||
continue
|
||||
default:
|
||||
return true
|
||||
}
|
||||
case slash:
|
||||
if b != '*' {
|
||||
return true
|
||||
}
|
||||
state = ccomment
|
||||
case minus:
|
||||
if b != '-' {
|
||||
return true
|
||||
}
|
||||
state = sqlcomment
|
||||
case ccomment:
|
||||
if b == '*' {
|
||||
state = endcomment
|
||||
}
|
||||
case sqlcomment:
|
||||
if b == '\n' {
|
||||
state = code
|
||||
}
|
||||
case endcomment:
|
||||
switch b {
|
||||
case '/':
|
||||
state = code
|
||||
case '*':
|
||||
state = endcomment
|
||||
default:
|
||||
state = ccomment
|
||||
}
|
||||
}
|
||||
}
|
||||
return state == slash || state == minus
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func Fuzz_notWhitespace(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add(";")
|
||||
f.Add("0")
|
||||
f.Add("-")
|
||||
f.Add("-0")
|
||||
f.Add("--")
|
||||
f.Add("--0")
|
||||
f.Add("--\n")
|
||||
f.Add("--0\n")
|
||||
f.Add("/0")
|
||||
f.Add("/*")
|
||||
f.Add("/*/")
|
||||
f.Add("/**")
|
||||
f.Add("/*0")
|
||||
f.Add("/**/")
|
||||
f.Add("/***/")
|
||||
f.Add("/**0/")
|
||||
f.Add("\v")
|
||||
f.Add(" \v")
|
||||
f.Add("\xf0")
|
||||
f.Add("\000")
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
f.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
f.Fuzz(func(t *testing.T, str string) {
|
||||
if len(str) > 128 {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
c, err := db.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(*conn).Conn
|
||||
stmt, tail, err := conn.Prepare(str)
|
||||
stmt.Close()
|
||||
|
||||
// It's hard to be bug for bug compatible with SQLite.
|
||||
// We settle for somewhat less:
|
||||
// - if SQLite reports whitespace, we must too
|
||||
// - if we report whitespace, SQLite must not parse a statement
|
||||
if notWhitespace(str) {
|
||||
if stmt == nil && tail == "" && err == nil {
|
||||
t.Errorf("was whitespace: %q", str)
|
||||
}
|
||||
} else {
|
||||
if stmt != nil {
|
||||
t.Errorf("was not whitespace: %q (%v)", str, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# Embeddable Wasm build of SQLite
|
||||
|
||||
This folder includes an embeddable Wasm build of SQLite 3.47.2 for use with
|
||||
This folder includes an embeddable Wasm build of SQLite 3.51.1 for use with
|
||||
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
|
||||
|
||||
The following optional features are compiled in:
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
# Embeddable Wasm build of SQLite
|
||||
|
||||
This folder includes an embeddable Wasm build of SQLite, including the experimental
|
||||
This folder includes an alternative embeddable Wasm build of SQLite,
|
||||
which includes the experimental
|
||||
[`BEGIN CONCURRENT`](https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md) and
|
||||
[Wal2](https://sqlite.org/cgi/src/doc/wal2/doc/wal2.md) patches.
|
||||
|
||||
It also enables the optional
|
||||
[`UPDATE … ORDER BY … LIMIT`](https://sqlite.org/lang_update.html#optional_limit_and_order_by_clauses) and
|
||||
[`DELETE … ORDER BY … LIMIT`](https://sqlite.org/lang_delete.html#optional_limit_and_order_by_clauses) clauses,
|
||||
and the [`WITHIN GROUP ORDER BY`](https://sqlite.org/compile.html#enable_ordered_set_aggregates) aggregate syntax.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> This package is experimental.
|
||||
> It is built from the `bedrock` branch of SQLite,
|
||||
> since that is _currently_ the most stable, maintained branch to include both features.
|
||||
> since that is _currently_ the most stable, maintained branch to include these features.
|
||||
|
||||
> [!CAUTION]
|
||||
> The Wal2 journaling mode creates databases that other versions of SQLite cannot access.
|
||||
|
||||
Binary file not shown.
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/go-sqlite3/ext/stats"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
)
|
||||
|
||||
@@ -15,7 +16,7 @@ func Test_bcw2(t *testing.T) {
|
||||
|
||||
tmp := filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))
|
||||
|
||||
db, err := driver.Open("file:" + tmp + "?_pragma=journal_mode(wal2)&_txlock=concurrent")
|
||||
db, err := driver.Open("file:"+tmp+"?_pragma=journal_mode(wal2)&_txlock=concurrent", stats.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -32,6 +33,16 @@ func Test_bcw2(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`DELETE FROM test LIMIT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`SELECT median() WITHIN GROUP (ORDER BY col) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -42,7 +53,7 @@ func Test_bcw2(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != "3.48.0" {
|
||||
if version != "3.52.0" {
|
||||
t.Error(version)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,21 +7,24 @@ ROOT=../../
|
||||
BINARYEN="$ROOT/tools/binaryen/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk/bin"
|
||||
|
||||
trap 'rm -rf build/ sqlite/ bcw2.tmp' EXIT
|
||||
trap 'rm -rf sqlite/ build/ bcw2.tmp' EXIT
|
||||
|
||||
mkdir -p sqlite/
|
||||
mkdir -p build/ext/
|
||||
cp "$ROOT"/sqlite3/*.[ch] build/
|
||||
cp "$ROOT"/sqlite3/*.patch build/
|
||||
cd sqlite/
|
||||
|
||||
# https://sqlite.org/src/info/08cfa7e8b3090151
|
||||
curl -# https://sqlite.org/src/tarball/sqlite.tar.gz?r=08cfa7e8 | tar xz
|
||||
# https://sqlite.org/src/info/f273f6b8245c5dca
|
||||
curl -#L https://github.com/sqlite/sqlite/archive/7c126d7.tar.gz | tar xz --strip-components=1
|
||||
# curl -#L https://sqlite.org/src/tarball/sqlite.tar.gz?r=f273f6b824 | tar xz --strip-components=1
|
||||
|
||||
cd sqlite
|
||||
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
|
||||
MSYS_NO_PATHCONV=1 nmake /f makefile.msc sqlite3.c
|
||||
MSYS_NO_PATHCONV=1 nmake /f makefile.msc sqlite3.c "OPTS=-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT -DSQLITE_ENABLE_ORDERED_SET_AGGREGATES"
|
||||
else
|
||||
sh configure
|
||||
make sqlite3.c
|
||||
sh configure --enable-update-limit
|
||||
make verify-source
|
||||
OPTS=-DSQLITE_ENABLE_ORDERED_SET_AGGREGATES make sqlite3.c
|
||||
fi
|
||||
cd ~-
|
||||
|
||||
@@ -38,28 +41,33 @@ mv sqlite/ext/misc/spellfix.c build/ext/
|
||||
mv sqlite/ext/misc/uint.c build/ext/
|
||||
|
||||
cd build
|
||||
cat *.patch | patch --no-backup-if-mismatch
|
||||
cat *.patch | patch -p0 --no-backup-if-mismatch
|
||||
cd ~-
|
||||
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -g0 -O2 \
|
||||
-Wall -Wextra -Wno-unused-parameter -Wno-unused-function \
|
||||
-o bcw2.wasm "build/main.c" \
|
||||
-I"build" \
|
||||
-o bcw2.wasm build/main.c \
|
||||
-I"$ROOT/sqlite3/libc" -I"build" \
|
||||
-mexec-model=reactor \
|
||||
-matomics -msimd128 -mmutable-globals -mmultivalue \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-mmutable-globals -mnontrapping-fptoint \
|
||||
-msimd128 -mbulk-memory -msign-ext \
|
||||
-mreference-types -mmultivalue \
|
||||
-mno-extended-const \
|
||||
-fno-stack-protector \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
-Wl,--initial-memory=327680 \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT \
|
||||
-DSQLITE_ENABLE_ORDERED_SET_AGGREGATES \
|
||||
-DSQLITE_EXPERIMENTAL_PRAGMA_20251114 \
|
||||
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
|
||||
$(awk '{print "-Wl,--export="$0}' ../exports.txt)
|
||||
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize bcw2.wasm -o bcw2.tmp
|
||||
"$BINARYEN/wasm-opt" -g --strip --strip-producers -c -O3 \
|
||||
bcw2.tmp -o bcw2.wasm \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
"$BINARYEN/wasm-opt" -g bcw2.tmp -o bcw2.wasm \
|
||||
--low-memory-unused --gufa --generate-global-effects --converge -O3 \
|
||||
--enable-mutable-globals --enable-nontrapping-float-to-int \
|
||||
--enable-simd --enable-bulk-memory --enable-sign-ext \
|
||||
--enable-reference-types --enable-multivalue \
|
||||
--strip --strip-producers
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
module github.com/ncruces/go-sqlite3/embed/bcw2
|
||||
|
||||
go 1.21
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.23.0
|
||||
|
||||
require github.com/ncruces/go-sqlite3 v0.21.0
|
||||
require github.com/ncruces/go-sqlite3 v0.30.3
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.8.2 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
github.com/ncruces/sort v0.1.6 // indirect
|
||||
github.com/tetratelabs/wazero v1.11.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
github.com/ncruces/go-sqlite3 v0.21.0 h1:EwKFoy1hHEopN4sFZarmi+McXdbCcbTuLixhEayXVbQ=
|
||||
github.com/ncruces/go-sqlite3 v0.21.0/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
|
||||
github.com/ncruces/go-sqlite3 v0.30.3 h1:X/CgWW9GzmIAkEPrifhKqf0cC15DuOVxAJaHFTTAURQ=
|
||||
github.com/ncruces/go-sqlite3 v0.30.3/go.mod h1:AxKu9sRxkludimFocbktlY6LiYSkxiI5gTA8r+os/Nw=
|
||||
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
|
||||
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
|
||||
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
github.com/ncruces/sort v0.1.6 h1:TrsJfGRH1AoWoaeB4/+gCohot9+cA6u/INaH5agIhNk=
|
||||
github.com/ncruces/sort v0.1.6/go.mod h1:obJToO4rYr6VWP0Uw5FYymgYGt3Br4RXcs/JdKaXAPk=
|
||||
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
|
||||
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
|
||||
@@ -11,13 +11,14 @@ package bcw2
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
//go:embed bcw2.wasm
|
||||
var binary []byte
|
||||
var binary string
|
||||
|
||||
func init() {
|
||||
sqlite3.Binary = binary
|
||||
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
|
||||
}
|
||||
|
||||
@@ -12,22 +12,25 @@ trap 'rm -f sqlite3.tmp' EXIT
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -g0 -O2 \
|
||||
-Wall -Wextra -Wno-unused-parameter -Wno-unused-function \
|
||||
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-I"$ROOT/sqlite3/libc" -I"$ROOT/sqlite3" \
|
||||
-mexec-model=reactor \
|
||||
-matomics -msimd128 -mmutable-globals -mmultivalue \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-mmutable-globals -mnontrapping-fptoint \
|
||||
-msimd128 -mbulk-memory -msign-ext \
|
||||
-mreference-types -mmultivalue \
|
||||
-mno-extended-const \
|
||||
-fno-stack-protector \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
-Wl,--initial-memory=327680 \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
-DSQLITE_EXPERIMENTAL_PRAGMA_20251114 \
|
||||
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
|
||||
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
||||
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||
"$BINARYEN/wasm-opt" -g --strip --strip-producers -c -O3 \
|
||||
sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
"$BINARYEN/wasm-opt" -g sqlite3.tmp -o sqlite3.wasm \
|
||||
--low-memory-unused --gufa --generate-global-effects --converge -O3 \
|
||||
--enable-mutable-globals --enable-nontrapping-float-to-int \
|
||||
--enable-simd --enable-bulk-memory --enable-sign-ext \
|
||||
--enable-reference-types --enable-multivalue \
|
||||
--strip --strip-producers
|
||||
@@ -59,13 +59,14 @@ sqlite3_db_filename
|
||||
sqlite3_db_name
|
||||
sqlite3_db_readonly
|
||||
sqlite3_db_release_memory
|
||||
sqlite3_db_status
|
||||
sqlite3_db_status64
|
||||
sqlite3_declare_vtab
|
||||
sqlite3_errcode
|
||||
sqlite3_errmsg
|
||||
sqlite3_error_offset
|
||||
sqlite3_errstr
|
||||
sqlite3_exec
|
||||
sqlite3_exec_go
|
||||
sqlite3_expanded_sql
|
||||
sqlite3_file_control
|
||||
sqlite3_filename_database
|
||||
@@ -77,8 +78,10 @@ sqlite3_get_autocommit
|
||||
sqlite3_get_auxdata
|
||||
sqlite3_hard_heap_limit64
|
||||
sqlite3_interrupt
|
||||
sqlite3_invoke_busy_handler_go
|
||||
sqlite3_last_insert_rowid
|
||||
sqlite3_limit
|
||||
sqlite3_log_go
|
||||
sqlite3_malloc64
|
||||
sqlite3_open_v2
|
||||
sqlite3_overload_function
|
||||
@@ -95,6 +98,7 @@ sqlite3_result_error_toobig
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_null
|
||||
sqlite3_result_pointer_go
|
||||
sqlite3_result_subtype
|
||||
sqlite3_result_text_go
|
||||
sqlite3_result_value
|
||||
sqlite3_result_zeroblob64
|
||||
@@ -123,6 +127,7 @@ sqlite3_value_int64
|
||||
sqlite3_value_nochange
|
||||
sqlite3_value_numeric_type
|
||||
sqlite3_value_pointer_go
|
||||
sqlite3_value_subtype
|
||||
sqlite3_value_text
|
||||
sqlite3_value_type
|
||||
sqlite3_vtab_collation
|
||||
|
||||
@@ -8,13 +8,16 @@ package embed
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
//go:embed sqlite3.wasm
|
||||
var binary []byte
|
||||
var binary string
|
||||
|
||||
func init() {
|
||||
sqlite3.Binary = binary
|
||||
if sqlite3.Binary == nil {
|
||||
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func Test_init(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != "3.47.2" {
|
||||
if version != "3.51.1" {
|
||||
t.Error(version)
|
||||
}
|
||||
}
|
||||
|
||||
Binary file not shown.
68
error.go
68
error.go
@@ -2,7 +2,6 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -12,10 +11,10 @@ import (
|
||||
//
|
||||
// https://sqlite.org/c3ref/errcode.html
|
||||
type Error struct {
|
||||
str string
|
||||
sys error
|
||||
msg string
|
||||
sql string
|
||||
code uint64
|
||||
code res_t
|
||||
}
|
||||
|
||||
// Code returns the primary error code for this error.
|
||||
@@ -29,28 +28,34 @@ func (e *Error) Code() ErrorCode {
|
||||
//
|
||||
// https://sqlite.org/rescode.html
|
||||
func (e *Error) ExtendedCode() ExtendedErrorCode {
|
||||
return ExtendedErrorCode(e.code)
|
||||
return xErrorCode(e.code)
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *Error) Error() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("sqlite3: ")
|
||||
|
||||
if e.str != "" {
|
||||
b.WriteString(e.str)
|
||||
} else {
|
||||
b.WriteString(strconv.Itoa(int(e.code)))
|
||||
}
|
||||
b.WriteString(util.ErrorCodeString(e.code))
|
||||
|
||||
if e.msg != "" {
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.msg)
|
||||
}
|
||||
if e.sys != nil {
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.sys.Error())
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying operating system error
|
||||
// that caused the I/O error or failure to open a file.
|
||||
//
|
||||
// https://sqlite.org/c3ref/system_errno.html
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.sys
|
||||
}
|
||||
|
||||
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
|
||||
//
|
||||
// It makes it possible to do:
|
||||
@@ -83,7 +88,7 @@ func (e *Error) As(err any) bool {
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.Code() == BUSY
|
||||
return e.Code() == BUSY || e.Code() == INTERRUPT
|
||||
}
|
||||
|
||||
// Timeout returns true for [BUSY_TIMEOUT] errors.
|
||||
@@ -98,22 +103,31 @@ func (e *Error) SQL() string {
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ErrorCode) Error() string {
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
return util.ErrorCodeString(e)
|
||||
}
|
||||
|
||||
// As converts this error to an [ExtendedErrorCode].
|
||||
func (e ErrorCode) As(err any) bool {
|
||||
c, ok := err.(*xErrorCode)
|
||||
if ok {
|
||||
*c = xErrorCode(e)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ErrorCode) Temporary() bool {
|
||||
return e == BUSY
|
||||
return e == BUSY || e == INTERRUPT
|
||||
}
|
||||
|
||||
// ExtendedCode returns the extended error code for this error.
|
||||
func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
|
||||
return ExtendedErrorCode(e)
|
||||
return xErrorCode(e)
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ExtendedErrorCode) Error() string {
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
return util.ErrorCodeString(e)
|
||||
}
|
||||
|
||||
// Is tests whether this error matches a given [ErrorCode].
|
||||
@@ -133,7 +147,7 @@ func (e ExtendedErrorCode) As(err any) bool {
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ExtendedErrorCode) Temporary() bool {
|
||||
return ErrorCode(e) == BUSY
|
||||
return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
|
||||
}
|
||||
|
||||
// Timeout returns true for [BUSY_TIMEOUT] errors.
|
||||
@@ -146,27 +160,23 @@ func (e ExtendedErrorCode) Code() ErrorCode {
|
||||
return ErrorCode(e)
|
||||
}
|
||||
|
||||
func errorCode(err error, def ErrorCode) (msg string, code uint32) {
|
||||
func errorCode(err error, def ErrorCode) (msg string, code res_t) {
|
||||
switch code := err.(type) {
|
||||
case nil:
|
||||
return "", _OK
|
||||
case ErrorCode:
|
||||
return "", uint32(code)
|
||||
return "", res_t(code)
|
||||
case xErrorCode:
|
||||
return "", uint32(code)
|
||||
return "", res_t(code)
|
||||
case *Error:
|
||||
return code.msg, uint32(code.code)
|
||||
return code.msg, res_t(code.code)
|
||||
}
|
||||
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
switch {
|
||||
case errors.As(err, &xcode):
|
||||
code = uint32(xcode)
|
||||
case errors.As(err, &ecode):
|
||||
code = uint32(ecode)
|
||||
default:
|
||||
code = uint32(def)
|
||||
if errors.As(err, &xcode) {
|
||||
code = res_t(xcode)
|
||||
} else {
|
||||
code = res_t(def)
|
||||
}
|
||||
return err.Error(), code
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestError(t *testing.T) {
|
||||
if !errors.Is(err, xErrorCode(0x8080)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if s := err.Error(); s != "sqlite3: 32896" {
|
||||
if s := err.Error(); s != "sqlite3: unknown error" {
|
||||
t.Errorf("got %q", s)
|
||||
}
|
||||
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
@@ -59,14 +59,14 @@ func TestError_Temporary(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code uint64
|
||||
code res_t
|
||||
want bool
|
||||
}{
|
||||
{"ERROR", uint64(ERROR), false},
|
||||
{"BUSY", uint64(BUSY), true},
|
||||
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
|
||||
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
|
||||
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
|
||||
{"ERROR", res_t(ERROR), false},
|
||||
{"BUSY", res_t(BUSY), true},
|
||||
{"BUSY_RECOVERY", res_t(BUSY_RECOVERY), true},
|
||||
{"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), true},
|
||||
{"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -83,7 +83,7 @@ func TestError_Temporary(t *testing.T) {
|
||||
}
|
||||
}
|
||||
{
|
||||
err := ExtendedErrorCode(tt.code)
|
||||
err := xErrorCode(tt.code)
|
||||
if got := err.Temporary(); got != tt.want {
|
||||
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
@@ -97,14 +97,14 @@ func TestError_Timeout(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code uint64
|
||||
code res_t
|
||||
want bool
|
||||
}{
|
||||
{"ERROR", uint64(ERROR), false},
|
||||
{"BUSY", uint64(BUSY), false},
|
||||
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
|
||||
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
|
||||
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
|
||||
{"ERROR", res_t(ERROR), false},
|
||||
{"BUSY", res_t(BUSY), false},
|
||||
{"BUSY_RECOVERY", res_t(BUSY_RECOVERY), false},
|
||||
{"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), false},
|
||||
{"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -115,7 +115,7 @@ func TestError_Timeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
{
|
||||
err := ExtendedErrorCode(tt.code)
|
||||
err := xErrorCode(tt.code)
|
||||
if got := err.Timeout(); got != tt.want {
|
||||
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
@@ -136,8 +136,8 @@ func Test_ErrorCode_Error(t *testing.T) {
|
||||
// Test all error codes.
|
||||
for i := 0; i == int(ErrorCode(i)); i++ {
|
||||
want := "sqlite3: "
|
||||
r := db.call("sqlite3_errstr", uint64(i))
|
||||
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
|
||||
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
|
||||
want += util.ReadString(db.mod, ptr, _MAX_NAME)
|
||||
|
||||
got := ErrorCode(i).Error()
|
||||
if got != want {
|
||||
@@ -156,12 +156,12 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
// Test all extended error codes.
|
||||
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
|
||||
for i := 0; i == int(xErrorCode(i)); i++ {
|
||||
want := "sqlite3: "
|
||||
r := db.call("sqlite3_errstr", uint64(i))
|
||||
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
|
||||
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
|
||||
want += util.ReadString(db.mod, ptr, _MAX_NAME)
|
||||
|
||||
got := ExtendedErrorCode(i).Error()
|
||||
got := xErrorCode(i).Error()
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q, with %d", got, want, i)
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func Test_errorCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
arg error
|
||||
wantMsg string
|
||||
wantCode uint32
|
||||
wantCode res_t
|
||||
}{
|
||||
{nil, "", _OK},
|
||||
{ERROR, "", util.ERROR},
|
||||
@@ -190,7 +190,7 @@ func Test_errorCode(t *testing.T) {
|
||||
if gotMsg != tt.wantMsg {
|
||||
t.Errorf("errorCode() gotMsg = %q, want %q", gotMsg, tt.wantMsg)
|
||||
}
|
||||
if gotCode != uint32(tt.wantCode) {
|
||||
if gotCode != tt.wantCode {
|
||||
t.Errorf("errorCode() gotCode = %d, want %d", gotCode, tt.wantCode)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -25,13 +25,24 @@ you can load into your database connections.
|
||||
creates [pivot tables](https://github.com/jakethaw/pivot_vtab).
|
||||
- [`github.com/ncruces/go-sqlite3/ext/regexp`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/regexp)
|
||||
provides regular expression functions.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/serdes`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/serdes)
|
||||
(de)serializes databases.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement)
|
||||
creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab).
|
||||
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
provides [statistics](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html) functions.
|
||||
provides [statistics](https://oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html) functions.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
|
||||
provides [Unicode aware](https://sqlite.org/src/dir/ext/icu) functions.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/uuid`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/uuid)
|
||||
generates [UUIDs](https://en.wikipedia.org/wiki/Universally_unique_identifier).
|
||||
- [`github.com/ncruces/go-sqlite3/ext/zorder`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/zorder)
|
||||
maps multidimensional data to one dimension.
|
||||
maps multidimensional data to one dimension.
|
||||
|
||||
### Packages
|
||||
|
||||
These packages may also be useful to work with SQLite:
|
||||
|
||||
- [`github.com/ncruces/decimal`](https://pkg.go.dev/github.com/ncruces/decimal)
|
||||
decimal arithmetic.
|
||||
- [`github.com/ncruces/julianday`](https://pkg.go.dev/github.com/ncruces/julianday)
|
||||
Julian day math.
|
||||
|
||||
@@ -59,7 +59,8 @@ func (c *cursor) Next() error {
|
||||
}
|
||||
|
||||
func (c *cursor) RowID() (int64, error) {
|
||||
return int64(c.rowID), nil
|
||||
// One-based RowID for consistency with carray and other tables.
|
||||
return int64(c.rowID) + 1, nil
|
||||
}
|
||||
|
||||
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
|
||||
|
||||
@@ -88,9 +88,9 @@ func Example() {
|
||||
|
||||
func Test_cursor_Column(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, array.Register)
|
||||
db, err := driver.Open(dsn, array.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ func Register(db *sqlite3.Conn) error {
|
||||
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
|
||||
|
||||
func readblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
_ = arg[5] // bounds check
|
||||
|
||||
blob, err := getAuxBlob(ctx, arg, false)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
@@ -78,6 +80,8 @@ func readblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
|
||||
func writeblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
_ = arg[5] // bounds check
|
||||
|
||||
blob, err := getAuxBlob(ctx, arg, true)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package blobio_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -34,7 +35,8 @@ func Example() {
|
||||
const message = "Hello BLOB!"
|
||||
|
||||
// Create the BLOB.
|
||||
r, err := db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
|
||||
r, err := db.Exec(`INSERT INTO test VALUES (:data)`,
|
||||
sql.Named("data", sqlite3.ZeroBlob(len(message))))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -45,15 +47,19 @@ func Example() {
|
||||
}
|
||||
|
||||
// Write the BLOB.
|
||||
_, err = db.Exec(`SELECT writeblob('main', 'test', 'col', ?, 0, ?)`,
|
||||
id, message)
|
||||
_, err = db.Exec(`SELECT writeblob('main', 'test', 'col', :rowid, :offset, :message)`,
|
||||
sql.Named("rowid", id),
|
||||
sql.Named("offset", 0),
|
||||
sql.Named("message", message))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the BLOB.
|
||||
_, err = db.Exec(`SELECT readblob('main', 'test', 'col', ?, 0, ?)`,
|
||||
id, sqlite3.Pointer(os.Stdout))
|
||||
_, err = db.Exec(`SELECT readblob('main', 'test', 'col', :rowid, :offset, :writer)`,
|
||||
sql.Named("rowid", id),
|
||||
sql.Named("offset", 0),
|
||||
sql.Named("writer", sqlite3.Pointer(os.Stdout)))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -64,7 +70,7 @@ func Example() {
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(blobio.Register)
|
||||
sqlite3.AutoExtension(array.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func Test_readblob(t *testing.T) {
|
||||
@@ -138,18 +144,16 @@ func Test_readblob(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnText(0)
|
||||
if got != tt.want1 {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != tt.want1 {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnText(0)
|
||||
if got != tt.want2 {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != tt.want2 {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
|
||||
err = stmt.Err()
|
||||
@@ -278,7 +282,7 @@ func Test_openblob(t *testing.T) {
|
||||
}
|
||||
|
||||
want := []string{"\xca\xfe", "\xba\xbe"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/util/sql3util"
|
||||
)
|
||||
|
||||
// Register registers the bloom_filter virtual table:
|
||||
@@ -34,6 +35,8 @@ type bloom struct {
|
||||
hashes int
|
||||
}
|
||||
|
||||
const vtab = `CREATE TABLE x(present, word TEXT HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`
|
||||
|
||||
func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err error) {
|
||||
b := bloom{
|
||||
db: db,
|
||||
@@ -55,11 +58,9 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
|
||||
}
|
||||
|
||||
if len(arg) > 1 {
|
||||
b.prob, err = strconv.ParseFloat(arg[1], 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if b.prob <= 0 || b.prob >= 1 {
|
||||
var ok bool
|
||||
b.prob, ok = sql3util.ParseFloat(arg[1])
|
||||
if !ok || b.prob <= 0 || b.prob >= 1 {
|
||||
return nil, util.ErrorString("bloom: probability must be in the range (0,1)")
|
||||
}
|
||||
} else {
|
||||
@@ -80,8 +81,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
|
||||
|
||||
b.bytes = numBytes(nelem, b.prob)
|
||||
|
||||
err = db.DeclareVTab(
|
||||
`CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`)
|
||||
err = db.DeclareVTab(vtab)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -115,8 +115,7 @@ func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom
|
||||
storage: table + "_storage",
|
||||
}
|
||||
|
||||
err = db.DeclareVTab(
|
||||
`CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`)
|
||||
err = db.DeclareVTab(vtab)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -232,7 +231,7 @@ func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) {
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
for n := 0; n < b.hashes; n++ {
|
||||
for n := range b.hashes {
|
||||
hash := calcHash(n, blob)
|
||||
hash %= uint64(b.bytes * 8)
|
||||
bitpos := byte(hash % 8)
|
||||
@@ -268,13 +267,13 @@ func (b *bloom) Open() (sqlite3.VTabCursor, error) {
|
||||
|
||||
type cursor struct {
|
||||
*bloom
|
||||
arg *sqlite3.Value
|
||||
arg sqlite3.Value
|
||||
eof bool
|
||||
}
|
||||
|
||||
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
c.eof = false
|
||||
c.arg = &arg[0]
|
||||
c.arg = arg[0]
|
||||
blob := arg[0].RawBlob()
|
||||
|
||||
f, err := c.db.OpenBlob(c.schema, c.storage, "data", 1, false)
|
||||
@@ -312,7 +311,7 @@ func (c *cursor) Column(ctx sqlite3.Context, n int) error {
|
||||
case 0:
|
||||
ctx.ResultBool(true)
|
||||
case 1:
|
||||
ctx.ResultValue(*c.arg)
|
||||
ctx.ResultValue(c.arg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(bloom.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
|
||||
@@ -56,7 +56,7 @@ func Register(db *sqlite3.Conn) error {
|
||||
done.Add(key)
|
||||
}
|
||||
|
||||
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
|
||||
err := db.DeclareVTab(`CREATE TABLE x(id INT,depth INT,root HIDDEN,tablename TEXT HIDDEN,idcolumn TEXT HIDDEN,parentcolumn TEXT HIDDEN)`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -84,10 +84,11 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
cost := 1e7
|
||||
|
||||
for i, cst := range idx.Constraint {
|
||||
if !cst.Usable {
|
||||
switch {
|
||||
case !cst.Usable:
|
||||
continue
|
||||
}
|
||||
if plan&1 == 0 && cst.Column == _COL_ROOT {
|
||||
|
||||
case plan&1 == 0 && cst.Column == _COL_ROOT:
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= 1
|
||||
@@ -97,9 +98,8 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
Omit: true,
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf0 == 0 && cst.Column == _COL_DEPTH {
|
||||
|
||||
case plan&0xf0 == 0 && cst.Column == _COL_DEPTH:
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 4
|
||||
@@ -110,9 +110,8 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
plan |= 2
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf00 == 0 && cst.Column == _COL_TABLENAME {
|
||||
|
||||
case plan&0xf00 == 0 && cst.Column == _COL_TABLENAME:
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 8
|
||||
@@ -123,9 +122,8 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
Omit: true,
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN {
|
||||
|
||||
case plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN:
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 12
|
||||
@@ -135,9 +133,8 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
Omit: true,
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN {
|
||||
|
||||
case plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN:
|
||||
switch cst.Op {
|
||||
case sqlite3.INDEX_CONSTRAINT_EQ:
|
||||
plan |= posi << 16
|
||||
@@ -147,7 +144,6 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
Omit: true,
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +154,7 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
return sqlite3.CONSTRAINT
|
||||
}
|
||||
|
||||
idx.IdxFlags = sqlite3.INDEX_SCAN_HEX
|
||||
idx.EstimatedCost = cost
|
||||
idx.IdxNum = plan
|
||||
return nil
|
||||
@@ -214,12 +211,14 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
c.nodes = []node{{root, 0}}
|
||||
set := util.Set[int64]{}
|
||||
set.Add(root)
|
||||
for i := 0; i < len(c.nodes); i++ {
|
||||
for i := range c.nodes {
|
||||
curr := c.nodes[i]
|
||||
if curr.depth >= maxDepth {
|
||||
continue
|
||||
}
|
||||
stmt.BindInt64(1, curr.id)
|
||||
if err := stmt.BindInt64(1, curr.id); err != nil {
|
||||
return err
|
||||
}
|
||||
for stmt.Step() {
|
||||
if stmt.ColumnType(0) == sqlite3.INTEGER {
|
||||
next := stmt.ColumnInt64(0)
|
||||
@@ -229,7 +228,9 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
stmt.Reset()
|
||||
if err := stmt.Reset(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(closure.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func Example() {
|
||||
|
||||
@@ -30,7 +30,7 @@ func Register(db *sqlite3.Conn) error {
|
||||
// RegisterFS registers the CSV virtual table.
|
||||
// If a filename is specified, fsys is used to open the file.
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) {
|
||||
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
|
||||
var (
|
||||
filename string
|
||||
data string
|
||||
@@ -214,7 +214,10 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
return err
|
||||
}
|
||||
if c.table.header {
|
||||
c.Next() // skip header
|
||||
err = c.Next() // skip header
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.rowID = 0
|
||||
return c.Next()
|
||||
@@ -251,19 +254,15 @@ func (c *cursor) Column(ctx sqlite3.Context, col int) error {
|
||||
|
||||
switch typ {
|
||||
case numeric, integer:
|
||||
if strings.TrimLeft(txt, "+-0123456789") == "" {
|
||||
if i, err := strconv.ParseInt(txt, 10, 64); err == nil {
|
||||
ctx.ResultInt64(i)
|
||||
return nil
|
||||
}
|
||||
if i, err := strconv.ParseInt(txt, 10, 64); err == nil {
|
||||
ctx.ResultInt64(i)
|
||||
return nil
|
||||
}
|
||||
fallthrough
|
||||
case real:
|
||||
if strings.TrimLeft(txt, "+-.0123456789Ee") == "" {
|
||||
if f, err := strconv.ParseFloat(txt, 64); err == nil {
|
||||
ctx.ResultFloat(f)
|
||||
return nil
|
||||
}
|
||||
if f, ok := sql3util.ParseFloat(txt); ok {
|
||||
ctx.ResultFloat(f)
|
||||
return nil
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
|
||||
@@ -3,6 +3,7 @@ package csv_test
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -56,7 +57,7 @@ func Example() {
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(csv.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
@@ -146,20 +147,21 @@ func TestAffinity(t *testing.T) {
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnText(0); got != "1" {
|
||||
t.Errorf("got %q want 1", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != "1" {
|
||||
t.Errorf("got %q want 1", got)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnText(0); got != "0.1" {
|
||||
t.Errorf("got %q want 0.1", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != "0.1" {
|
||||
t.Errorf("got %q want 0.1", got)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnText(0); got != "e" {
|
||||
t.Errorf("got %q want e", got)
|
||||
}
|
||||
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != "e" {
|
||||
t.Errorf("got %q want e", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
//go:build !go1.23
|
||||
|
||||
package fileio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Adapted from: https://research.swtch.com/coro
|
||||
|
||||
const errCoroCanceled = util.ErrorString("coroutine canceled")
|
||||
|
||||
func coroNew[In, Out any](f func(In, func(Out) In) Out) (resume func(In) (Out, bool), cancel func()) {
|
||||
type msg[T any] struct {
|
||||
panic any
|
||||
val T
|
||||
}
|
||||
|
||||
cin := make(chan msg[In])
|
||||
cout := make(chan msg[Out])
|
||||
running := true
|
||||
resume = func(in In) (out Out, ok bool) {
|
||||
if !running {
|
||||
return
|
||||
}
|
||||
cin <- msg[In]{val: in}
|
||||
m := <-cout
|
||||
if m.panic != nil {
|
||||
panic(m.panic)
|
||||
}
|
||||
return m.val, running
|
||||
}
|
||||
cancel = func() {
|
||||
if !running {
|
||||
return
|
||||
}
|
||||
e := fmt.Errorf("%w", errCoroCanceled)
|
||||
cin <- msg[In]{panic: e}
|
||||
m := <-cout
|
||||
if m.panic != nil && m.panic != e {
|
||||
panic(m.panic)
|
||||
}
|
||||
}
|
||||
yield := func(out Out) In {
|
||||
cout <- msg[Out]{val: out}
|
||||
m := <-cin
|
||||
if m.panic != nil {
|
||||
panic(m.panic)
|
||||
}
|
||||
return m.val
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if running {
|
||||
running = false
|
||||
cout <- msg[Out]{panic: recover()}
|
||||
}
|
||||
}()
|
||||
var out Out
|
||||
m := <-cin
|
||||
if m.panic == nil {
|
||||
out = f(m.val, yield)
|
||||
}
|
||||
running = false
|
||||
cout <- msg[Out]{val: out}
|
||||
}()
|
||||
return resume, cancel
|
||||
}
|
||||
@@ -18,7 +18,7 @@ func Register(db *sqlite3.Conn) error {
|
||||
return RegisterFS(db, nil)
|
||||
}
|
||||
|
||||
// Register registers SQL functions readfile, lsmode,
|
||||
// RegisterFS registers SQL functions readfile, lsmode,
|
||||
// and the table-valued function fsdir;
|
||||
// fsys will be used to read files and list directories.
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
@@ -30,7 +30,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
|
||||
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
|
||||
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
|
||||
err := db.DeclareVTab(`CREATE TABLE x(name TEXT,mode INT,mtime TIMESTAMP,data BLOB,path HIDDEN,dir HIDDEN)`)
|
||||
if err == nil {
|
||||
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
|
||||
}
|
||||
|
||||
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
func readfile(fsys fs.FS) sqlite3.ScalarFunction {
|
||||
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
var err error
|
||||
var data []byte
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
|
||||
func Test_lsmode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, fileio.Register)
|
||||
db, err := driver.Open(dsn, fileio.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -53,9 +53,9 @@ func Test_readfile(t *testing.T) {
|
||||
|
||||
for _, fsys := range []fs.FS{nil, os.DirFS(".")} {
|
||||
t.Run("", func(t *testing.T) {
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, func(c *sqlite3.Conn) error {
|
||||
db, err := driver.Open(dsn, func(c *sqlite3.Conn) error {
|
||||
fileio.RegisterFS(c, fsys)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ package fileio
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"iter"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
@@ -62,12 +63,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
|
||||
|
||||
type cursor struct {
|
||||
fsdir
|
||||
base string
|
||||
resume resume
|
||||
cancel func()
|
||||
curr entry
|
||||
eof bool
|
||||
rowID int64
|
||||
base string
|
||||
next func() (entry, bool)
|
||||
stop func()
|
||||
curr entry
|
||||
eof bool
|
||||
rowID int64
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
@@ -77,8 +78,8 @@ type entry struct {
|
||||
}
|
||||
|
||||
func (c *cursor) Close() error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
if c.stop != nil {
|
||||
c.stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -101,14 +102,26 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
c.base = base
|
||||
}
|
||||
|
||||
c.resume, c.cancel = pull(c, root)
|
||||
c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
|
||||
walkDir := func(path string, d fs.DirEntry, err error) error {
|
||||
if yield(entry{d, err, path}) {
|
||||
return nil
|
||||
}
|
||||
return fs.SkipAll
|
||||
}
|
||||
if c.fsys != nil {
|
||||
fs.WalkDir(c.fsys, root, walkDir)
|
||||
} else {
|
||||
filepath.WalkDir(root, walkDir)
|
||||
}
|
||||
})
|
||||
c.eof = false
|
||||
c.rowID = 0
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func (c *cursor) Next() error {
|
||||
curr, ok := next(c)
|
||||
curr, ok := c.next()
|
||||
c.curr = curr
|
||||
c.eof = !ok
|
||||
c.rowID++
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
//go:build !go1.23
|
||||
|
||||
package fileio
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type resume = func(struct{}) (entry, bool)
|
||||
|
||||
func next(c *cursor) (entry, bool) {
|
||||
return c.resume(struct{}{})
|
||||
}
|
||||
|
||||
func pull(c *cursor, root string) (resume, func()) {
|
||||
return coroNew(func(_ struct{}, yield func(entry) struct{}) entry {
|
||||
walkDir := func(path string, d fs.DirEntry, err error) error {
|
||||
yield(entry{d, err, path})
|
||||
return nil
|
||||
}
|
||||
if c.fsys != nil {
|
||||
fs.WalkDir(c.fsys, root, walkDir)
|
||||
} else {
|
||||
filepath.WalkDir(root, walkDir)
|
||||
}
|
||||
return entry{}
|
||||
})
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build go1.23
|
||||
|
||||
package fileio
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"iter"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type resume = func() (entry, bool)
|
||||
|
||||
func next(c *cursor) (entry, bool) {
|
||||
return c.resume()
|
||||
}
|
||||
|
||||
func pull(c *cursor, root string) (resume, func()) {
|
||||
return iter.Pull(func(yield func(entry) bool) {
|
||||
walkDir := func(path string, d fs.DirEntry, err error) error {
|
||||
if yield(entry{d, err, path}) {
|
||||
return nil
|
||||
}
|
||||
return fs.SkipAll
|
||||
}
|
||||
if c.fsys != nil {
|
||||
fs.WalkDir(c.fsys, root, walkDir)
|
||||
} else {
|
||||
filepath.WalkDir(root, walkDir)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -21,9 +21,9 @@ func Test_fsdir(t *testing.T) {
|
||||
|
||||
for _, fsys := range []fs.FS{nil, os.DirFS(".")} {
|
||||
t.Run("", func(t *testing.T) {
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, func(c *sqlite3.Conn) error {
|
||||
db, err := driver.Open(dsn, func(c *sqlite3.Conn) error {
|
||||
fileio.RegisterFS(c, fsys)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
|
||||
func Test_writefile(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, Register)
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
_ "crypto/md5"
|
||||
_ "crypto/sha1"
|
||||
_ "crypto/sha256"
|
||||
_ "crypto/sha3"
|
||||
_ "crypto/sha512"
|
||||
"testing"
|
||||
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
_ "golang.org/x/crypto/blake2s"
|
||||
_ "golang.org/x/crypto/md4"
|
||||
_ "golang.org/x/crypto/ripemd160"
|
||||
_ "golang.org/x/crypto/sha3"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -55,7 +55,7 @@ func TestRegister(t *testing.T) {
|
||||
{"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"},
|
||||
}
|
||||
|
||||
db, err := driver.Open(tmp, Register)
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
113
ext/ipaddr/ipaddr.go
Normal file
113
ext/ipaddr/ipaddr.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Package ipaddr provides functions to manipulate IPs and CIDRs.
|
||||
//
|
||||
// It provides the following functions:
|
||||
// - ipcontains(prefix, ip)
|
||||
// - ipoverlaps(prefix1, prefix2)
|
||||
// - ipfamily(ip/prefix)
|
||||
// - iphost(ip/prefix)
|
||||
// - ipmasklen(prefix)
|
||||
// - ipnetwork(prefix)
|
||||
package ipaddr
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register IP/CIDR functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
return errors.Join(
|
||||
db.CreateFunction("ipcontains", 2, flags, contains),
|
||||
db.CreateFunction("ipoverlaps", 2, flags, overlaps),
|
||||
db.CreateFunction("ipfamily", 1, flags, family),
|
||||
db.CreateFunction("iphost", 1, flags, host),
|
||||
db.CreateFunction("ipmasklen", 1, flags, masklen),
|
||||
db.CreateFunction("ipnetwork", 1, flags, network))
|
||||
}
|
||||
|
||||
func contains(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
addr, err := netip.ParseAddr(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultBool(prefix.Contains(addr))
|
||||
}
|
||||
|
||||
func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix1, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
prefix2, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultBool(prefix1.Overlaps(prefix2))
|
||||
}
|
||||
|
||||
func family(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
addr, err := addr(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
switch {
|
||||
case addr.Is4():
|
||||
ctx.ResultInt(4)
|
||||
case addr.Is6():
|
||||
ctx.ResultInt(6)
|
||||
}
|
||||
}
|
||||
|
||||
func host(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
addr, err := addr(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
buf, _ := addr.MarshalText()
|
||||
ctx.ResultRawText(buf)
|
||||
}
|
||||
|
||||
func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultInt(prefix.Bits())
|
||||
}
|
||||
|
||||
func network(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
buf, _ := prefix.Masked().MarshalText()
|
||||
ctx.ResultRawText(buf)
|
||||
}
|
||||
|
||||
func addr(text string) (netip.Addr, error) {
|
||||
addr, err := netip.ParseAddr(text)
|
||||
if err != nil {
|
||||
if prefix, err := netip.ParsePrefix(text); err == nil {
|
||||
return prefix.Addr(), nil
|
||||
}
|
||||
if addrpt, err := netip.ParseAddrPort(text); err == nil {
|
||||
return addrpt.Addr(), nil
|
||||
}
|
||||
}
|
||||
return addr, err
|
||||
}
|
||||
88
ext/ipaddr/ipaddr_test.go
Normal file
88
ext/ipaddr/ipaddr_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package ipaddr_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/ipaddr"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
"github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(dsn, ipaddr.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var got string
|
||||
|
||||
err = db.QueryRow(`SELECT ipfamily('::1')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "6" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipfamily('[::1]:80')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "6" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipfamily('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "4" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT iphost('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "192.168.1.5" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipmasklen('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "24" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipnetwork('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "192.168.1.0/24" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipcontains('192.168.1.0/24', '192.168.1.5')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "1" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipoverlaps('192.168.1.0/24', '192.168.1.5/32')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "1" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
return errors.Join(
|
||||
sqlite3.CreateModule(db, "lines", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN, delim HIDDEN)`)
|
||||
if err == nil {
|
||||
err = db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
}),
|
||||
sqlite3.CreateModule(db, "lines_read", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN, delim HIDDEN)`)
|
||||
if err == nil {
|
||||
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
|
||||
}
|
||||
@@ -58,19 +58,29 @@ type lines struct {
|
||||
fsys fs.FS
|
||||
}
|
||||
|
||||
func (l lines) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
func (l lines) BestIndex(idx *sqlite3.IndexInfo) (err error) {
|
||||
err = sqlite3.CONSTRAINT
|
||||
for i, cst := range idx.Constraint {
|
||||
if cst.Column == 1 && cst.Op == sqlite3.INDEX_CONSTRAINT_EQ && cst.Usable {
|
||||
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
|
||||
continue
|
||||
}
|
||||
switch cst.Column {
|
||||
case 1:
|
||||
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
|
||||
Omit: true,
|
||||
ArgvIndex: 1,
|
||||
}
|
||||
idx.EstimatedCost = 1e6
|
||||
idx.EstimatedRows = 100
|
||||
return nil
|
||||
err = nil
|
||||
case 2:
|
||||
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
|
||||
Omit: true,
|
||||
ArgvIndex: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
return sqlite3.CONSTRAINT
|
||||
return err
|
||||
}
|
||||
|
||||
func (l lines) Open() (sqlite3.VTabCursor, error) {
|
||||
@@ -85,6 +95,7 @@ type cursor struct {
|
||||
line []byte
|
||||
rowID int64
|
||||
eof bool
|
||||
delim byte
|
||||
}
|
||||
|
||||
func (c *cursor) EOF() bool {
|
||||
@@ -140,6 +151,15 @@ func (c *reader) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ)
|
||||
}
|
||||
|
||||
c.delim = '\n'
|
||||
if len(arg) > 1 {
|
||||
b := arg[1].RawText()
|
||||
if len(b) != 1 {
|
||||
return fmt.Errorf("lines: delimiter must be a single byte%.0w", sqlite3.MISMATCH)
|
||||
}
|
||||
c.delim = b[0]
|
||||
}
|
||||
|
||||
c.reader = bufio.NewReader(r)
|
||||
c.closer, _ = r.(io.Closer)
|
||||
c.rowID = 0
|
||||
@@ -150,7 +170,12 @@ func (c *reader) Next() (err error) {
|
||||
c.line = c.line[:0]
|
||||
for more := true; more; {
|
||||
var line []byte
|
||||
line, more, err = c.reader.ReadLine()
|
||||
if c.delim == '\n' {
|
||||
line, more, err = c.reader.ReadLine()
|
||||
} else {
|
||||
line, err = c.reader.ReadSlice(c.delim)
|
||||
more = err == bufio.ErrBufferFull
|
||||
}
|
||||
c.line = append(c.line, line...)
|
||||
}
|
||||
if err == io.EOF {
|
||||
@@ -177,18 +202,27 @@ func (c *buffer) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ)
|
||||
}
|
||||
|
||||
c.delim = '\n'
|
||||
if len(arg) > 1 {
|
||||
b := arg[1].RawText()
|
||||
if len(b) != 1 {
|
||||
return fmt.Errorf("lines: delimiter must be a single byte%.0w", sqlite3.MISMATCH)
|
||||
}
|
||||
c.delim = b[0]
|
||||
}
|
||||
|
||||
c.rowID = 0
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func (c *buffer) Next() error {
|
||||
i := bytes.IndexByte(c.data, '\n')
|
||||
i := bytes.IndexByte(c.data, c.delim)
|
||||
j := i + 1
|
||||
switch {
|
||||
case i < 0:
|
||||
i = len(c.data)
|
||||
j = i
|
||||
case i > 0 && c.data[i-1] == '\r':
|
||||
case i > 0 && c.delim == '\n' && c.data[i-1] == '\r':
|
||||
i--
|
||||
}
|
||||
c.eof = len(c.data) == 0
|
||||
|
||||
@@ -67,9 +67,9 @@ func Example() {
|
||||
|
||||
func Test_lines(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, lines.Register)
|
||||
db, err := driver.Open(dsn, lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -98,9 +98,9 @@ func Test_lines(t *testing.T) {
|
||||
|
||||
func Test_lines_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, lines.Register)
|
||||
db, err := driver.Open(dsn, lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -123,9 +123,9 @@ func Test_lines_error(t *testing.T) {
|
||||
|
||||
func Test_lines_read(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, lines.Register)
|
||||
db, err := driver.Open(dsn, lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -155,15 +155,15 @@ func Test_lines_read(t *testing.T) {
|
||||
|
||||
func Test_lines_test(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, lines.Register)
|
||||
db, err := driver.Open(dsn, lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query(`SELECT rowid, line FROM lines_read(?)`, "lines_test.go")
|
||||
rows, err := db.Query(`SELECT rowid, line FROM lines_read(?, '}')`, "lines_test.go")
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
t.Skip(err)
|
||||
}
|
||||
|
||||
@@ -25,14 +25,14 @@ type table struct {
|
||||
cols []*sqlite3.Value
|
||||
}
|
||||
|
||||
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) {
|
||||
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err error) {
|
||||
if len(arg) != 3 {
|
||||
return nil, fmt.Errorf("pivot: wrong number of arguments")
|
||||
}
|
||||
|
||||
t := &table{db: db}
|
||||
defer func() {
|
||||
if res == nil {
|
||||
if ret == nil {
|
||||
t.Close()
|
||||
}
|
||||
}()
|
||||
@@ -55,6 +55,8 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err e
|
||||
t.keys[i] = name
|
||||
create.WriteString(sep)
|
||||
create.WriteString(name)
|
||||
create.WriteString(" ")
|
||||
create.WriteString(stmt.ColumnDeclType(i))
|
||||
sep = ","
|
||||
}
|
||||
stmt.Close()
|
||||
@@ -71,8 +73,11 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err e
|
||||
for stmt.Step() {
|
||||
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
|
||||
t.cols = append(t.cols, stmt.ColumnValue(0).Dup())
|
||||
create.WriteString(",")
|
||||
create.WriteString(sep)
|
||||
create.WriteString(name)
|
||||
create.WriteString(" ")
|
||||
create.WriteString(stmt.ColumnDeclType(1))
|
||||
sep = ","
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
@@ -99,10 +104,11 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err e
|
||||
}
|
||||
|
||||
func (t *table) Close() error {
|
||||
var errs []error
|
||||
for _, c := range t.cols {
|
||||
c.Close()
|
||||
errs = append(errs, c.Close())
|
||||
}
|
||||
return nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
|
||||
@@ -206,7 +212,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
func (c *cursor) Next() error {
|
||||
if c.scan.Step() {
|
||||
count := c.scan.ColumnCount()
|
||||
for i := 0; i < count; i++ {
|
||||
for i := range count {
|
||||
err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -3,6 +3,7 @@ package pivot_test
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -85,7 +86,7 @@ func Example() {
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(pivot.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
@@ -140,10 +141,10 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnInt(0); got != 3 {
|
||||
t.Errorf("got %d, want 3", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnInt(0); got != 3 {
|
||||
t.Errorf("got %d, want 3", got)
|
||||
}
|
||||
|
||||
err = db.Exec(`ALTER TABLE v_x RENAME TO v_y`)
|
||||
|
||||
@@ -16,7 +16,9 @@ package regexp
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"regexp/syntax"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
@@ -50,33 +52,83 @@ func Register(db *sqlite3.Conn) error {
|
||||
// SELECT column WHERE column GLOB :glob_prefix AND column REGEXP :regexp
|
||||
//
|
||||
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
|
||||
func GlobPrefix(re *regexp.Regexp) string {
|
||||
prefix, complete := re.LiteralPrefix()
|
||||
i := strings.IndexAny(prefix, "*?[")
|
||||
if i < 0 {
|
||||
if complete {
|
||||
return prefix
|
||||
}
|
||||
i = len(prefix)
|
||||
func GlobPrefix(expr string) string {
|
||||
re, err := syntax.Parse(expr, syntax.Perl)
|
||||
if err != nil {
|
||||
return "" // no match possible
|
||||
}
|
||||
return prefix[:i] + "*"
|
||||
prog, err := syntax.Compile(re.Simplify())
|
||||
if err != nil {
|
||||
return "" // notest
|
||||
}
|
||||
|
||||
i := &prog.Inst[prog.Start]
|
||||
|
||||
var empty syntax.EmptyOp
|
||||
loop1:
|
||||
for {
|
||||
switch i.Op {
|
||||
case syntax.InstFail:
|
||||
return "" // notest
|
||||
case syntax.InstCapture, syntax.InstNop:
|
||||
// skip
|
||||
case syntax.InstEmptyWidth:
|
||||
empty |= syntax.EmptyOp(i.Arg)
|
||||
default:
|
||||
break loop1
|
||||
}
|
||||
i = &prog.Inst[i.Out]
|
||||
}
|
||||
if empty&syntax.EmptyBeginText == 0 {
|
||||
return "*" // not anchored
|
||||
}
|
||||
|
||||
var glob strings.Builder
|
||||
loop2:
|
||||
for {
|
||||
switch i.Op {
|
||||
case syntax.InstFail:
|
||||
return "" // notest
|
||||
case syntax.InstCapture, syntax.InstEmptyWidth, syntax.InstNop:
|
||||
// skip
|
||||
case syntax.InstRune, syntax.InstRune1:
|
||||
if len(i.Rune) != 1 || syntax.Flags(i.Arg)&syntax.FoldCase != 0 {
|
||||
break loop2
|
||||
}
|
||||
switch r := i.Rune[0]; r {
|
||||
case '*', '?', '[', utf8.RuneError:
|
||||
break loop2
|
||||
default:
|
||||
glob.WriteRune(r)
|
||||
}
|
||||
default:
|
||||
break loop2
|
||||
}
|
||||
i = &prog.Inst[i.Out]
|
||||
}
|
||||
|
||||
glob.WriteByte('*')
|
||||
return glob.String()
|
||||
}
|
||||
|
||||
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
|
||||
func load(ctx sqlite3.Context, arg []sqlite3.Value, i int) (*regexp.Regexp, error) {
|
||||
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
re, ok = arg[i].Pointer().(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[i].Text())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re = r
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, r)
|
||||
ctx.SetAuxData(i, re)
|
||||
}
|
||||
return re, nil
|
||||
}
|
||||
|
||||
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 0, arg[0].Text())
|
||||
re, err := load(ctx, arg, 0)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
@@ -86,18 +138,17 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
|
||||
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
re, err := load(ctx, arg, 1)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
|
||||
text := arg[0].RawText()
|
||||
ctx.ResultBool(re.Match(text))
|
||||
}
|
||||
|
||||
func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
re, err := load(ctx, arg, 1)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
@@ -112,7 +163,7 @@ func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
|
||||
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
re, err := load(ctx, arg, 1)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
@@ -137,7 +188,7 @@ func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
|
||||
func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
re, err := load(ctx, arg, 1)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
@@ -165,14 +216,14 @@ func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
|
||||
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
re, err := load(ctx, arg, 1)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
|
||||
text := arg[0].RawText()
|
||||
repl := arg[2].RawText()
|
||||
text := arg[0].RawText()
|
||||
var pos, n int
|
||||
if len(arg) > 3 {
|
||||
pos = arg[3].Int()
|
||||
|
||||
@@ -3,8 +3,10 @@ package regexp
|
||||
import (
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
@@ -13,9 +15,9 @@ import (
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, Register)
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -36,7 +38,7 @@ func TestRegister(t *testing.T) {
|
||||
{`regexp_instr('Hello', '.', 6)`, ""},
|
||||
{`regexp_substr('Hello', 'el.')`, "ell"},
|
||||
{`regexp_replace('Hello', 'llo', 'll')`, "Hell"},
|
||||
// https://www.postgresql.org/docs/current/functions-matching.html
|
||||
// https://postgresql.org/docs/current/functions-matching.html
|
||||
{`regexp_count('ABCABCAXYaxy', 'A.')`, "3"},
|
||||
{`regexp_count('ABCABCAXYaxy', '(?i)A.', 1)`, "4"},
|
||||
{`regexp_instr('number of your street, town zip, FR', '[^,]+', 1, 2)`, "23"},
|
||||
@@ -78,9 +80,9 @@ func TestRegister(t *testing.T) {
|
||||
|
||||
func TestRegister_errors(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, Register)
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -103,24 +105,81 @@ func TestRegister_errors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_pointer(t *testing.T) {
|
||||
t.Parallel()
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var got int
|
||||
err = db.QueryRow(`SELECT regexp_count('ABCABCAXYaxy', ?, 1)`,
|
||||
sqlite3.Pointer(regexp.MustCompile(`(?i)A.`))).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 4 {
|
||||
t.Errorf("got %d, want %d", got, 4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
re string
|
||||
want string
|
||||
}{
|
||||
{``, ""},
|
||||
{`a`, "a"},
|
||||
{`a*`, "*"},
|
||||
{`a+`, "a*"},
|
||||
{`ab*`, "a*"},
|
||||
{`ab+`, "ab*"},
|
||||
{`a\?b`, "a*"},
|
||||
{`[`, ""},
|
||||
{``, "*"},
|
||||
{`^`, "*"},
|
||||
{`a`, "*"},
|
||||
{`ab`, "*"},
|
||||
{`^a`, "a*"},
|
||||
{`^a*`, "*"},
|
||||
{`^a+`, "a*"},
|
||||
{`^ab*`, "a*"},
|
||||
{`^ab+`, "ab*"},
|
||||
{`^a\?b`, "a*"},
|
||||
{`^[a-z]`, "*"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.re, func(t *testing.T) {
|
||||
if got := GlobPrefix(regexp.MustCompile(tt.re)); got != tt.want {
|
||||
t.Errorf("GlobPrefix() = %v, want %v", got, tt.want)
|
||||
if got := GlobPrefix(tt.re); got != tt.want {
|
||||
t.Errorf("GlobPrefix(%v) = %v, want %v", tt.re, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzGlobPrefix(f *testing.F) {
|
||||
f.Add(``, ``)
|
||||
f.Add(`[`, ``)
|
||||
f.Add(`^`, ``)
|
||||
f.Add(`a`, `a`)
|
||||
f.Add(`ab`, `b`)
|
||||
f.Add(`^a`, `a`)
|
||||
f.Add(`^a*`, `ab`)
|
||||
f.Add(`^a+`, `ab`)
|
||||
f.Add(`^ab*`, `ab`)
|
||||
f.Add(`^ab+`, `ab`)
|
||||
f.Add(`^a\?b`, `ab`)
|
||||
f.Add(`^[a-z]`, `ab`)
|
||||
|
||||
f.Fuzz(func(t *testing.T, lit, str string) {
|
||||
re, err := regexp.Compile(lit)
|
||||
if err != nil {
|
||||
t.SkipNow()
|
||||
}
|
||||
if re.MatchString(str) {
|
||||
prefix, ok := strings.CutSuffix(GlobPrefix(lit), "*")
|
||||
if !ok {
|
||||
t.Fatalf("missing * after %q for %q with %q", prefix, lit, str)
|
||||
}
|
||||
if !strings.HasPrefix(str, prefix) {
|
||||
t.Fatalf("missing prefix %q for %q with %q", prefix, lit, str)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
72
ext/serdes/serdes.go
Normal file
72
ext/serdes/serdes.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Package serdes provides functions to (de)serialize databases.
|
||||
package serdes
|
||||
|
||||
import (
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/util/vfsutil"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
)
|
||||
|
||||
const vfsName = "github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS"
|
||||
|
||||
func init() {
|
||||
vfs.Register(vfsName, sliceVFS{})
|
||||
}
|
||||
|
||||
var fileToOpen = make(chan *[]byte, 1)
|
||||
|
||||
// Serialize backs up a database into a byte slice.
|
||||
//
|
||||
// https://sqlite.org/c3ref/serialize.html
|
||||
func Serialize(db *sqlite3.Conn, schema string) ([]byte, error) {
|
||||
var file []byte
|
||||
fileToOpen <- &file
|
||||
err := db.Backup(schema, "file:serdes.db?nolock=1&vfs="+vfsName)
|
||||
return file, err
|
||||
}
|
||||
|
||||
// Deserialize restores a database from a byte slice,
|
||||
// DESTROYING any contents previously stored in schema.
|
||||
//
|
||||
// To non-destructively open a database from a byte slice,
|
||||
// consider alternatives like the ["reader"] or ["memdb"] VFSes.
|
||||
//
|
||||
// This differs from the similarly named SQLite API
|
||||
// in that it DOES NOT disconnect from schema
|
||||
// to reopen as an in-memory database.
|
||||
//
|
||||
// https://sqlite.org/c3ref/deserialize.html
|
||||
//
|
||||
// ["memdb"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb
|
||||
// ["reader"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs
|
||||
func Deserialize(db *sqlite3.Conn, schema string, data []byte) error {
|
||||
fileToOpen <- &data
|
||||
return db.Restore(schema, "file:serdes.db?immutable=1&vfs="+vfsName)
|
||||
}
|
||||
|
||||
type sliceVFS struct{}
|
||||
|
||||
func (sliceVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
|
||||
if flags&vfs.OPEN_MAIN_DB == 0 || name != "serdes.db" {
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
select {
|
||||
case file := <-fileToOpen:
|
||||
return (*vfsutil.SliceFile)(file), flags | vfs.OPEN_MEMORY, nil
|
||||
default:
|
||||
return nil, flags, sqlite3.MISUSE
|
||||
}
|
||||
}
|
||||
|
||||
func (sliceVFS) Delete(name string, dirSync bool) error {
|
||||
// notest // no journals to delete
|
||||
return sqlite3.IOERR_DELETE
|
||||
}
|
||||
|
||||
func (sliceVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
|
||||
return name == "serdes.db", nil
|
||||
}
|
||||
|
||||
func (sliceVFS) FullPathname(name string) (string, error) {
|
||||
return name, nil
|
||||
}
|
||||
115
ext/serdes/serdes_test.go
Normal file
115
ext/serdes/serdes_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package serdes_test
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/serdes"
|
||||
)
|
||||
|
||||
//go:embed testdata/wal.db
|
||||
var walDB []byte
|
||||
|
||||
func Test_wal(t *testing.T) {
|
||||
db, err := sqlite3.Open("testdata/wal.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
data, err := serdes.Serialize(db, "main")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
compareDBs(t, data, walDB)
|
||||
|
||||
err = serdes.Deserialize(db, "temp", walDB)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_northwind(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
input, err := httpGet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = serdes.Deserialize(db, "temp", input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
output, err := serdes.Serialize(db, "temp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
compareDBs(t, input, output)
|
||||
}
|
||||
|
||||
func compareDBs(t *testing.T, a, b []byte) {
|
||||
if len(a) != len(b) {
|
||||
t.Fatal("lengths are different")
|
||||
}
|
||||
for i := range a {
|
||||
// These may be different.
|
||||
switch {
|
||||
case 24 <= i && i < 28:
|
||||
// File change counter.
|
||||
continue
|
||||
case 40 <= i && i < 44:
|
||||
// Schema cookie.
|
||||
continue
|
||||
case 92 <= i && i < 100:
|
||||
// SQLite version that wrote the file.
|
||||
continue
|
||||
}
|
||||
if a[i] != b[i] {
|
||||
t.Errorf("difference at %d: %d %d", i, a[i], b[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func httpGet() ([]byte, error) {
|
||||
res, err := http.Get("https://github.com/jpwhite3/northwind-SQLite3/raw/refs/heads/main/dist/northwind.db")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
return io.ReadAll(res.Body)
|
||||
}
|
||||
|
||||
func TestOpen_errors(t *testing.T) {
|
||||
_, err := sqlite3.Open("file:test.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
|
||||
_, err = sqlite3.Open("file:serdes.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.MISUSE) {
|
||||
t.Errorf("got %v, want sqlite3.MISUSE", err)
|
||||
}
|
||||
}
|
||||
BIN
ext/serdes/testdata/wal.db
vendored
Normal file
BIN
ext/serdes/testdata/wal.db
vendored
Normal file
Binary file not shown.
@@ -8,6 +8,7 @@ package statement
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
@@ -43,7 +44,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
|
||||
var str strings.Builder
|
||||
str.WriteString("CREATE TABLE x(")
|
||||
outputs := stmt.ColumnCount()
|
||||
for i := 0; i < outputs; i++ {
|
||||
for i := range outputs {
|
||||
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
|
||||
str.WriteString(sep)
|
||||
str.WriteString(name)
|
||||
@@ -150,17 +151,18 @@ type cursor struct {
|
||||
func (c *cursor) Close() error {
|
||||
if c.stmt == c.table.stmt {
|
||||
c.table.inuse = false
|
||||
c.stmt.ClearBindings()
|
||||
return c.stmt.Reset()
|
||||
return errors.Join(
|
||||
c.stmt.Reset(),
|
||||
c.stmt.ClearBindings())
|
||||
}
|
||||
return c.stmt.Close()
|
||||
}
|
||||
|
||||
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
c.arg = arg
|
||||
c.rowID = 0
|
||||
c.stmt.ClearBindings()
|
||||
if err := c.stmt.Reset(); err != nil {
|
||||
err := errors.Join(
|
||||
c.stmt.Reset(),
|
||||
c.stmt.ClearBindings())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -183,6 +185,8 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.arg = append(c.arg[:0], arg...)
|
||||
c.rowID = 0
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package statement_test
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -50,7 +51,7 @@ func Example() {
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(statement.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
@@ -91,7 +92,9 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
x := stmt.ColumnInt(0)
|
||||
y := stmt.ColumnInt(1)
|
||||
hypot := stmt.ColumnInt(2)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ANSI SQL Aggregate Functions
|
||||
|
||||
https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
https://oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
|
||||
## Built in aggregates
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ const (
|
||||
some
|
||||
)
|
||||
|
||||
func newBoolean(kind int) func() sqlite3.AggregateFunction {
|
||||
func newBoolean(kind int) sqlite3.AggregateConstructor {
|
||||
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,9 @@ func TestRegister_boolean(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
|
||||
19
ext/stats/kahan.go
Normal file
19
ext/stats/kahan.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package stats
|
||||
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
y := k.lo + x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
func (k *kahan) sub(x float64) {
|
||||
y := k.lo - x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
121
ext/stats/mode.go
Normal file
121
ext/stats/mode.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func newMode() sqlite3.AggregateFunction {
|
||||
return &mode{}
|
||||
}
|
||||
|
||||
type mode struct {
|
||||
ints counter[int64]
|
||||
reals counter[float64]
|
||||
texts counter[string]
|
||||
blobs counter[string]
|
||||
}
|
||||
|
||||
func (m mode) Value(ctx sqlite3.Context) {
|
||||
var (
|
||||
typ = sqlite3.NULL
|
||||
max uint
|
||||
i64 int64
|
||||
f64 float64
|
||||
str string
|
||||
)
|
||||
for k, v := range m.ints {
|
||||
if v > max || v == max && k < i64 {
|
||||
typ = sqlite3.INTEGER
|
||||
max = v
|
||||
i64 = k
|
||||
}
|
||||
}
|
||||
for k, v := range m.reals {
|
||||
if v > max || v == max && k < f64 {
|
||||
typ = sqlite3.FLOAT
|
||||
max = v
|
||||
f64 = k
|
||||
}
|
||||
}
|
||||
for k, v := range m.texts {
|
||||
if v > max || v == max && typ == sqlite3.TEXT && k < str {
|
||||
typ = sqlite3.TEXT
|
||||
max = v
|
||||
str = k
|
||||
}
|
||||
}
|
||||
for k, v := range m.blobs {
|
||||
if v > max || v == max && typ == sqlite3.BLOB && k < str {
|
||||
typ = sqlite3.BLOB
|
||||
max = v
|
||||
str = k
|
||||
}
|
||||
}
|
||||
switch typ {
|
||||
case sqlite3.INTEGER:
|
||||
ctx.ResultInt64(i64)
|
||||
case sqlite3.FLOAT:
|
||||
ctx.ResultFloat(f64)
|
||||
case sqlite3.TEXT:
|
||||
ctx.ResultText(str)
|
||||
case sqlite3.BLOB:
|
||||
ctx.ResultBlob(unsafe.Slice(unsafe.StringData(str), len(str)))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg[0].Type() {
|
||||
case sqlite3.INTEGER:
|
||||
if m.reals == nil {
|
||||
m.ints.add(arg[0].Int64())
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
case sqlite3.FLOAT:
|
||||
m.reals.add(arg[0].Float())
|
||||
for k, v := range m.ints {
|
||||
m.reals[float64(k)] += v
|
||||
}
|
||||
m.ints = nil
|
||||
case sqlite3.TEXT:
|
||||
m.texts.add(arg[0].Text())
|
||||
case sqlite3.BLOB:
|
||||
m.blobs.add(string(arg[0].RawBlob()))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg[0].Type() {
|
||||
case sqlite3.INTEGER:
|
||||
if m.reals == nil {
|
||||
m.ints.del(arg[0].Int64())
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
case sqlite3.FLOAT:
|
||||
m.reals.del(arg[0].Float())
|
||||
case sqlite3.TEXT:
|
||||
m.texts.del(arg[0].Text())
|
||||
case sqlite3.BLOB:
|
||||
m.blobs.del(string(arg[0].RawBlob()))
|
||||
}
|
||||
}
|
||||
|
||||
type counter[T comparable] map[T]uint
|
||||
|
||||
func (c *counter[T]) add(k T) {
|
||||
if (*c) == nil {
|
||||
(*c) = make(counter[T])
|
||||
}
|
||||
(*c)[k]++
|
||||
}
|
||||
|
||||
func (c counter[T]) del(k T) {
|
||||
if n := c[k]; n == 1 {
|
||||
delete(c, k)
|
||||
} else {
|
||||
c[k] = n - 1
|
||||
}
|
||||
}
|
||||
102
ext/stats/mode_test.go
Normal file
102
ext/stats/mode_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func TestRegister_mode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT mode(column1) FROM (VALUES (NULL), (1), (NULL), (2), (NULL), (3), (3))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnInt(0); got != 3 {
|
||||
t.Errorf("got %v, want 3", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (1), (1), (2), (2), (3))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnInt(0); got != 1 {
|
||||
t.Errorf("got %v, want 1", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (0.5), (1), (2.5), (2), (2.5))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnFloat(0); got != 2.5 {
|
||||
t.Errorf("got %v, want 2.5", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES ('red'), ('green'), ('blue'), ('red'))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != "red" {
|
||||
t.Errorf("got %q, want red", got)
|
||||
}
|
||||
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (X'cafebabe'), ('green'), ('blue'), (X'cafebabe'))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnText(0); got != "\xca\xfe\xba\xbe" {
|
||||
t.Errorf("got %q, want cafebabe", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT mode(column1) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
|
||||
FROM (VALUES (1), (1), (2.5), ('blue'), (X'cafebabe'), (1), (1))
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for stmt.Step() {
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (?), (?), (?), (?), (?))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stmt.BindInt(1, 1)
|
||||
stmt.BindInt(2, 1)
|
||||
stmt.BindInt(3, 2)
|
||||
stmt.BindFloat(4, 2)
|
||||
stmt.BindFloat(5, 2)
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnInt(0); got != 2 {
|
||||
t.Errorf("got %v, want 2", got)
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
101
ext/stats/moments.go
Normal file
101
ext/stats/moments.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package stats
|
||||
|
||||
import "math"
|
||||
|
||||
// Fisher–Pearson skewness and kurtosis using
|
||||
// Terriberry's algorithm with Kahan summation:
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
|
||||
|
||||
type moments struct {
|
||||
m1, m2, m3, m4 kahan
|
||||
n int64
|
||||
}
|
||||
|
||||
func (m moments) mean() float64 {
|
||||
return m.m1.hi
|
||||
}
|
||||
|
||||
func (m moments) var_pop() float64 {
|
||||
return m.m2.hi / float64(m.n)
|
||||
}
|
||||
|
||||
func (m moments) var_samp() float64 {
|
||||
return m.m2.hi / float64(m.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (m moments) stddev_pop() float64 {
|
||||
return math.Sqrt(m.var_pop())
|
||||
}
|
||||
|
||||
func (m moments) stddev_samp() float64 {
|
||||
return math.Sqrt(m.var_samp())
|
||||
}
|
||||
|
||||
func (m moments) skewness_pop() float64 {
|
||||
m2 := m.m2.hi
|
||||
if div := m2 * m2 * m2; div != 0 {
|
||||
return m.m3.hi * math.Sqrt(float64(m.n)/div)
|
||||
}
|
||||
return math.NaN()
|
||||
}
|
||||
|
||||
func (m moments) skewness_samp() float64 {
|
||||
n := m.n
|
||||
// https://mathworks.com/help/stats/skewness.html#f1132178
|
||||
return m.skewness_pop() * math.Sqrt(float64(n*(n-1))) / float64(n-2)
|
||||
}
|
||||
|
||||
func (m moments) kurtosis_pop() float64 {
|
||||
return m.raw_kurtosis_pop() - 3
|
||||
}
|
||||
|
||||
func (m moments) raw_kurtosis_pop() float64 {
|
||||
m2 := m.m2.hi
|
||||
if div := m2 * m2; div != 0 {
|
||||
return m.m4.hi * float64(m.n) / div
|
||||
}
|
||||
return math.NaN()
|
||||
}
|
||||
|
||||
func (m moments) kurtosis_samp() float64 {
|
||||
n := m.n
|
||||
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
|
||||
return k * float64(n-1) / float64((n-2)*(n-3))
|
||||
}
|
||||
|
||||
func (m moments) raw_kurtosis_samp() float64 {
|
||||
n := m.n
|
||||
// https://mathworks.com/help/stats/kurtosis.html#f4975293
|
||||
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
|
||||
return math.FMA(k, float64(n-1)/float64((n-2)*(n-3)), 3)
|
||||
}
|
||||
|
||||
func (m *moments) enqueue(x float64) {
|
||||
n := m.n + 1
|
||||
m.n = n
|
||||
d1 := x - m.m1.hi - m.m1.lo
|
||||
dn := d1 / float64(n)
|
||||
d2 := dn * dn
|
||||
t1 := d1 * dn * float64(n-1)
|
||||
m.m4.add(t1*d2*float64(n*n-3*n+3) + 6*d2*m.m2.hi - 4*dn*m.m3.hi)
|
||||
m.m3.add(t1*dn*float64(n-2) - 3*dn*m.m2.hi)
|
||||
m.m2.add(t1)
|
||||
m.m1.add(dn)
|
||||
}
|
||||
|
||||
func (m *moments) dequeue(x float64) {
|
||||
n := m.n - 1
|
||||
if n <= 0 {
|
||||
*m = moments{}
|
||||
return
|
||||
}
|
||||
m.n = n
|
||||
d1 := x - m.m1.hi - m.m1.lo
|
||||
dn := d1 / float64(n)
|
||||
d2 := dn * dn
|
||||
t1 := d1 * dn * float64(n+1)
|
||||
m.m4.sub(t1*d2*float64(n*n+3*n+3) - 6*d2*m.m2.hi - 4*dn*m.m3.hi)
|
||||
m.m3.sub(t1*dn*float64(n+2) - 3*dn*m.m2.hi)
|
||||
m.m2.sub(t1)
|
||||
m.m1.sub(dn)
|
||||
}
|
||||
87
ext/stats/moments_test.go
Normal file
87
ext/stats/moments_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_moments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var s1 moments
|
||||
s1.enqueue(1)
|
||||
s1.dequeue(1)
|
||||
if !math.IsNaN(s1.skewness_pop()) {
|
||||
t.Errorf("want NaN")
|
||||
}
|
||||
if !math.IsNaN(s1.raw_kurtosis_pop()) {
|
||||
t.Errorf("want NaN")
|
||||
}
|
||||
|
||||
s1.enqueue(+0.5377)
|
||||
s1.enqueue(+1.8339)
|
||||
s1.enqueue(-2.2588)
|
||||
s1.enqueue(+0.8622)
|
||||
s1.enqueue(+0.3188)
|
||||
s1.enqueue(-1.3077)
|
||||
s1.enqueue(-0.4336)
|
||||
s1.enqueue(+0.3426)
|
||||
s1.enqueue(+3.5784)
|
||||
s1.enqueue(+2.7694)
|
||||
|
||||
if got := s1.skewness_pop(); float32(got) != 0.106098293 {
|
||||
t.Errorf("got %v, want 0.1061", got)
|
||||
}
|
||||
if got := s1.skewness_samp(); float32(got) != 0.1258171 {
|
||||
t.Errorf("got %v, want 0.1258", got)
|
||||
}
|
||||
if got := s1.raw_kurtosis_pop(); float32(got) != 2.3121266 {
|
||||
t.Errorf("got %v, want 2.3121", got)
|
||||
}
|
||||
if got := s1.raw_kurtosis_samp(); float32(got) != 2.7482237 {
|
||||
t.Errorf("got %v, want 2.7483", got)
|
||||
}
|
||||
|
||||
var s2 welford
|
||||
|
||||
s2.enqueue(+0.5377)
|
||||
s2.enqueue(+1.8339)
|
||||
s2.enqueue(-2.2588)
|
||||
s2.enqueue(+0.8622)
|
||||
s2.enqueue(+0.3188)
|
||||
s2.enqueue(-1.3077)
|
||||
s2.enqueue(-0.4336)
|
||||
s2.enqueue(+0.3426)
|
||||
s2.enqueue(+3.5784)
|
||||
s2.enqueue(+2.7694)
|
||||
|
||||
if got, want := s1.mean(), s2.mean(); got != want {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := s1.stddev_pop(), s2.stddev_pop(); got != want {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := s1.stddev_samp(), s2.stddev_samp(); got != want {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
s1.enqueue(math.Pi)
|
||||
s1.enqueue(math.Sqrt2)
|
||||
s1.enqueue(math.E)
|
||||
s1.dequeue(math.Pi)
|
||||
s1.dequeue(math.E)
|
||||
s1.dequeue(math.Sqrt2)
|
||||
|
||||
if got := s1.skewness_pop(); float32(got) != 0.106098293 {
|
||||
t.Errorf("got %v, want 0.1061", got)
|
||||
}
|
||||
if got := s1.skewness_samp(); float32(got) != 0.1258171 {
|
||||
t.Errorf("got %v, want 0.1258", got)
|
||||
}
|
||||
if got := s1.raw_kurtosis_pop(); float32(got) != 2.3121266 {
|
||||
t.Errorf("got %v, want 2.3121", got)
|
||||
}
|
||||
if got := s1.raw_kurtosis_samp(); float32(got) != 2.7482237 {
|
||||
t.Errorf("got %v, want 2.7483", got)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,9 @@ import (
|
||||
"github.com/ncruces/sort/quick"
|
||||
)
|
||||
|
||||
// Compatible with:
|
||||
// https://sqlite.org/src/file/ext/misc/percentile.c
|
||||
|
||||
const (
|
||||
median = iota
|
||||
percentile_100
|
||||
@@ -18,7 +21,7 @@ const (
|
||||
percentile_disc
|
||||
)
|
||||
|
||||
func newPercentile(kind int) func() sqlite3.AggregateFunction {
|
||||
func newPercentile(kind int) sqlite3.AggregateConstructor {
|
||||
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
|
||||
}
|
||||
|
||||
|
||||
@@ -38,7 +38,9 @@ func TestRegister_percentile(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
if got := stmt.ColumnFloat(0); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
@@ -65,30 +67,30 @@ func TestRegister_percentile(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 5.5 {
|
||||
t.Errorf("got %v, want 5.5", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnFloat(0); got != 5.5 {
|
||||
t.Errorf("got %v, want 5.5", got)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnFloat(0); got != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnFloat(0); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 14.5 {
|
||||
t.Errorf("got %v, want 14.5", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnFloat(0); got != 14.5 {
|
||||
t.Errorf("got %v, want 14.5", got)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 16 {
|
||||
t.Errorf("got %v, want 16", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnFloat(0); got != 16 {
|
||||
t.Errorf("got %v, want 16", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
@@ -103,7 +105,9 @@ func TestRegister_percentile(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
if got := stmt.ColumnFloat(0); got != 4 {
|
||||
t.Errorf("got %v, want 4", got)
|
||||
}
|
||||
@@ -134,7 +138,9 @@ func TestRegister_percentile(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Error("want NULL")
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
// Package stats provides aggregate functions for statistics.
|
||||
//
|
||||
// Provided functions:
|
||||
// - stddev_pop: population standard deviation
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - var_pop: population variance
|
||||
// - var_samp: sample variance
|
||||
// - stddev_pop: population standard deviation
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - skewness_pop: Pearson population skewness
|
||||
// - skewness_samp: Pearson sample skewness
|
||||
// - kurtosis_pop: Fisher population excess kurtosis
|
||||
// - kurtosis_samp: Fisher sample excess kurtosis
|
||||
// - covar_pop: population covariance
|
||||
// - covar_samp: sample covariance
|
||||
// - corr: correlation coefficient
|
||||
// - corr: Pearson correlation coefficient
|
||||
// - regr_r2: correlation coefficient squared
|
||||
// - regr_avgx: average of the independent variable
|
||||
// - regr_avgy: average of the dependent variable
|
||||
@@ -17,10 +21,12 @@
|
||||
// - regr_count: count non-null pairs of variables
|
||||
// - regr_slope: slope of the least-squares-fit linear equation
|
||||
// - regr_intercept: y-intercept of the least-squares-fit linear equation
|
||||
// - regr_json: all regr stats in a JSON object
|
||||
// - percentile_disc: discrete percentile
|
||||
// - percentile_cont: continuous percentile
|
||||
// - median: median value
|
||||
// - regr_json: all regr stats as a JSON object
|
||||
// - percentile_disc: discrete quantile
|
||||
// - percentile_cont: continuous quantile
|
||||
// - percentile: continuous percentile
|
||||
// - median: middle value
|
||||
// - mode: most frequent value
|
||||
// - every: boolean and
|
||||
// - some: boolean or
|
||||
//
|
||||
@@ -41,7 +47,7 @@
|
||||
//
|
||||
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
|
||||
// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
// [ANSI SQL Aggregate Functions]: https://oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
package stats
|
||||
|
||||
import (
|
||||
@@ -52,13 +58,20 @@ import (
|
||||
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
const order = sqlite3.SELFORDER1 | flags
|
||||
const (
|
||||
flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
json = sqlite3.RESULT_SUBTYPE | flags
|
||||
order = sqlite3.SELFORDER1 | flags
|
||||
)
|
||||
return errors.Join(
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
|
||||
db.CreateWindowFunction("skewness_pop", 1, flags, newMoments(skewness_pop)),
|
||||
db.CreateWindowFunction("skewness_samp", 1, flags, newMoments(skewness_samp)),
|
||||
db.CreateWindowFunction("kurtosis_pop", 1, flags, newMoments(kurtosis_pop)),
|
||||
db.CreateWindowFunction("kurtosis_samp", 1, flags, newMoments(kurtosis_samp)),
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
|
||||
@@ -71,13 +84,14 @@ func Register(db *sqlite3.Conn) error {
|
||||
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)),
|
||||
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
|
||||
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
|
||||
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
|
||||
db.CreateWindowFunction("regr_json", 2, json, newCovariance(regr_json)),
|
||||
db.CreateWindowFunction("median", 1, order, newPercentile(median)),
|
||||
db.CreateWindowFunction("percentile", 2, order, newPercentile(percentile_100)),
|
||||
db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)),
|
||||
db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)),
|
||||
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some)),
|
||||
db.CreateWindowFunction("mode", 1, order, newMode))
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -85,6 +99,10 @@ const (
|
||||
var_samp
|
||||
stddev_pop
|
||||
stddev_samp
|
||||
skewness_pop
|
||||
skewness_samp
|
||||
kurtosis_pop
|
||||
kurtosis_samp
|
||||
corr
|
||||
regr_r2
|
||||
regr_sxx
|
||||
@@ -98,7 +116,24 @@ const (
|
||||
regr_json
|
||||
)
|
||||
|
||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
||||
func special(kind int, n int64) (null, zero bool) {
|
||||
switch kind {
|
||||
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
|
||||
return n <= 0, n == 1
|
||||
case regr_avgx, regr_avgy:
|
||||
return n <= 0, false
|
||||
case kurtosis_samp:
|
||||
return n <= 3, false
|
||||
case skewness_samp:
|
||||
return n <= 2, false
|
||||
case skewness_pop:
|
||||
return n <= 1, n == 2
|
||||
default:
|
||||
return n <= 1, false
|
||||
}
|
||||
}
|
||||
|
||||
func newVariance(kind int) sqlite3.AggregateConstructor {
|
||||
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
||||
}
|
||||
|
||||
@@ -108,6 +143,14 @@ type variance struct {
|
||||
}
|
||||
|
||||
func (fn *variance) Value(ctx sqlite3.Context) {
|
||||
switch null, zero := special(fn.kind, fn.n); {
|
||||
case zero:
|
||||
ctx.ResultFloat(0)
|
||||
return
|
||||
case null:
|
||||
return
|
||||
}
|
||||
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
@@ -138,7 +181,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
}
|
||||
|
||||
func newCovariance(kind int) func() sqlite3.AggregateFunction {
|
||||
func newCovariance(kind int) sqlite3.AggregateConstructor {
|
||||
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
||||
}
|
||||
|
||||
@@ -148,6 +191,18 @@ type covariance struct {
|
||||
}
|
||||
|
||||
func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
if fn.kind == regr_count {
|
||||
ctx.ResultInt64(fn.regr_count())
|
||||
return
|
||||
}
|
||||
switch null, zero := special(fn.kind, fn.n); {
|
||||
case zero:
|
||||
ctx.ResultFloat(0)
|
||||
return
|
||||
case null:
|
||||
return
|
||||
}
|
||||
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
@@ -172,11 +227,10 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
r = fn.regr_slope()
|
||||
case regr_intercept:
|
||||
r = fn.regr_intercept()
|
||||
case regr_count:
|
||||
ctx.ResultInt64(fn.regr_count())
|
||||
return
|
||||
case regr_json:
|
||||
ctx.ResultText(fn.regr_json())
|
||||
var buf [128]byte
|
||||
ctx.ResultRawText(fn.regr_json(buf[:0]))
|
||||
ctx.ResultSubtype('J')
|
||||
return
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
@@ -203,3 +257,51 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
fn.dequeue(fa, fb)
|
||||
}
|
||||
}
|
||||
|
||||
func newMoments(kind int) sqlite3.AggregateConstructor {
|
||||
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
|
||||
}
|
||||
|
||||
type momentfn struct {
|
||||
kind int
|
||||
moments
|
||||
}
|
||||
|
||||
func (fn *momentfn) Value(ctx sqlite3.Context) {
|
||||
switch null, zero := special(fn.kind, fn.n); {
|
||||
case zero:
|
||||
ctx.ResultFloat(0)
|
||||
return
|
||||
case null:
|
||||
return
|
||||
}
|
||||
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case skewness_pop:
|
||||
r = fn.skewness_pop()
|
||||
case skewness_samp:
|
||||
r = fn.skewness_samp()
|
||||
case kurtosis_pop:
|
||||
r = fn.kurtosis_pop()
|
||||
case kurtosis_samp:
|
||||
r = fn.kurtosis_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *momentfn) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a := arg[0]
|
||||
f := a.Float()
|
||||
if f != 0.0 || a.NumericType() != sqlite3.NULL {
|
||||
fn.enqueue(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *momentfn) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a := arg[0]
|
||||
f := a.Float()
|
||||
if f != 0.0 || a.NumericType() != sqlite3.NULL {
|
||||
fn.dequeue(f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package stats_test
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
sqlite3.AutoExtension(stats.Register)
|
||||
m.Run()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRegister_variance(t *testing.T) {
|
||||
@@ -29,21 +30,36 @@ func TestRegister_variance(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
sum(x), avg(x),
|
||||
var_samp(x), var_pop(x),
|
||||
stddev_samp(x), stddev_pop(x)
|
||||
stddev_samp(x), stddev_pop(x),
|
||||
skewness_samp(x), skewness_pop(x),
|
||||
kurtosis_samp(x), kurtosis_pop(x)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
if got := stmt.ColumnFloat(0); got != 40 {
|
||||
t.Errorf("got %v, want 40", got)
|
||||
}
|
||||
@@ -62,10 +78,27 @@ func TestRegister_variance(t *testing.T) {
|
||||
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(6); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(7); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(8); float32(got) != -3.3 {
|
||||
t.Errorf("got %v, want -3.3", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(9); got != -1.64 {
|
||||
t.Errorf("got %v, want -1.64", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
var_samp(x) OVER (ROWS 1 PRECEDING),
|
||||
var_pop(x) OVER (ROWS 1 PRECEDING),
|
||||
skewness_pop(x) OVER (ROWS 1 PRECEDING)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,12 +129,28 @@ func TestRegister_covariance(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
} else {
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want 0", got)
|
||||
}
|
||||
if got := stmt.ColumnType(1); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
stmt, _, err = db.Prepare(`SELECT
|
||||
corr(y, x), covar_samp(y, x), covar_pop(y, x),
|
||||
regr_avgy(y, x), regr_avgx(y, x),
|
||||
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
|
||||
@@ -111,53 +160,59 @@ func TestRegister_covariance(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(3); got != 4.2 {
|
||||
t.Errorf("got %v, want 4.2", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(4); got != 75 {
|
||||
t.Errorf("got %v, want 75", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(5); got != 14.8 {
|
||||
t.Errorf("got %v, want 14.8", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(6); got != 500 {
|
||||
t.Errorf("got %v, want 500", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(7); got != 85 {
|
||||
t.Errorf("got %v, want 85", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(8); got != 0.17 {
|
||||
t.Errorf("got %v, want 0.17", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(9); got != -8.55 {
|
||||
t.Errorf("got %v, want -8.55", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
|
||||
t.Errorf("got %v, want 0.9763513513513513", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(11); got != 5 {
|
||||
t.Errorf("got %v, want 5", got)
|
||||
}
|
||||
var a map[string]float64
|
||||
if err := stmt.ColumnJSON(12, &a); err != nil {
|
||||
t.Error(err)
|
||||
} else if got := a["count"]; got != 5 {
|
||||
t.Errorf("got %v, want 5", got)
|
||||
}
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(3); got != 4.2 {
|
||||
t.Errorf("got %v, want 4.2", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(4); got != 75 {
|
||||
t.Errorf("got %v, want 75", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(5); got != 14.8 {
|
||||
t.Errorf("got %v, want 14.8", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(6); got != 500 {
|
||||
t.Errorf("got %v, want 500", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(7); got != 85 {
|
||||
t.Errorf("got %v, want 85", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(8); got != 0.17 {
|
||||
t.Errorf("got %v, want 0.17", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(9); got != -8.55 {
|
||||
t.Errorf("got %v, want -8.55", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
|
||||
t.Errorf("got %v, want 0.9763513513513513", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(11); got != 5 {
|
||||
t.Errorf("got %v, want 5", got)
|
||||
}
|
||||
var a map[string]float64
|
||||
if err := stmt.ColumnJSON(12, &a); err != nil {
|
||||
t.Error(err)
|
||||
} else if got := a["count"]; got != 5 {
|
||||
t.Errorf("got %v, want 5", got)
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
|
||||
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
|
||||
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -171,6 +226,9 @@ func TestRegister_covariance(t *testing.T) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
if stmt.Err() != nil {
|
||||
t.Fatal(stmt.Err())
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
@@ -195,7 +253,9 @@ func Benchmark_average(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if !stmt.Step() {
|
||||
b.Fatal(stmt.Err())
|
||||
} else {
|
||||
want := float64(b.N) / 2
|
||||
if got := stmt.ColumnFloat(0); got != want {
|
||||
b.Errorf("got %v, want %v", got, want)
|
||||
@@ -229,7 +289,9 @@ func Benchmark_variance(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if stmt.Step() && b.N > 100 {
|
||||
if !stmt.Step() {
|
||||
b.Fatal(stmt.Err())
|
||||
} else if b.N > 100 {
|
||||
want := float64(b.N*b.N) / 12
|
||||
if got := stmt.ColumnFloat(0); want > (got-want)*float64(b.N) {
|
||||
b.Errorf("got %v, want %v", got, want)
|
||||
|
||||
@@ -3,22 +3,20 @@ package stats
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Welford's algorithm with Kahan summation:
|
||||
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
// See also:
|
||||
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
|
||||
|
||||
type welford struct {
|
||||
m1, m2 kahan
|
||||
n int64
|
||||
}
|
||||
|
||||
func (w welford) average() float64 {
|
||||
func (w welford) mean() float64 {
|
||||
return w.m1.hi
|
||||
}
|
||||
|
||||
@@ -39,17 +37,23 @@ func (w welford) stddev_samp() float64 {
|
||||
}
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
n := w.n + 1
|
||||
w.n = n
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
w.m1.add(d1 / float64(n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
n := w.n - 1
|
||||
if n <= 0 {
|
||||
*w = welford{}
|
||||
return
|
||||
}
|
||||
w.n = n
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
w.m1.sub(d1 / float64(n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
@@ -112,38 +116,35 @@ func (w welford2) regr_r2() float64 {
|
||||
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
|
||||
}
|
||||
|
||||
func (w welford2) regr_json() string {
|
||||
var json strings.Builder
|
||||
var num [32]byte
|
||||
json.Grow(128)
|
||||
json.WriteString(`{"count":`)
|
||||
json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10))
|
||||
json.WriteString(`,"avgy":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64))
|
||||
json.WriteString(`,"avgx":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64))
|
||||
json.WriteString(`,"syy":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64))
|
||||
json.WriteString(`,"sxx":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64))
|
||||
json.WriteString(`,"sxy":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64))
|
||||
json.WriteString(`,"slope":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64))
|
||||
json.WriteString(`,"intercept":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64))
|
||||
json.WriteString(`,"r2":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64))
|
||||
json.WriteByte('}')
|
||||
return json.String()
|
||||
func (w welford2) regr_json(dst []byte) []byte {
|
||||
dst = append(dst, `{"count":`...)
|
||||
dst = strconv.AppendInt(dst, w.regr_count(), 10)
|
||||
dst = append(dst, `,"avgy":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_avgy())
|
||||
dst = append(dst, `,"avgx":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_avgx())
|
||||
dst = append(dst, `,"syy":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_syy())
|
||||
dst = append(dst, `,"sxx":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_sxx())
|
||||
dst = append(dst, `,"sxy":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_sxy())
|
||||
dst = append(dst, `,"slope":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_slope())
|
||||
dst = append(dst, `,"intercept":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_intercept())
|
||||
dst = append(dst, `,"r2":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_r2())
|
||||
return append(dst, '}')
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(y, x float64) {
|
||||
w.n++
|
||||
n := w.n + 1
|
||||
w.n = n
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m1y.add(d1y / float64(w.n))
|
||||
w.m1x.add(d1x / float64(w.n))
|
||||
w.m1y.add(d1y / float64(n))
|
||||
w.m1x.add(d1x / float64(n))
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m2y.add(d1y * d2y)
|
||||
@@ -152,30 +153,19 @@ func (w *welford2) enqueue(y, x float64) {
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(y, x float64) {
|
||||
w.n--
|
||||
n := w.n - 1
|
||||
if n <= 0 {
|
||||
*w = welford2{}
|
||||
return
|
||||
}
|
||||
w.n = n
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m1y.sub(d1y / float64(w.n))
|
||||
w.m1x.sub(d1x / float64(w.n))
|
||||
w.m1y.sub(d1y / float64(n))
|
||||
w.m1x.sub(d1x / float64(n))
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m2y.sub(d1y * d2y)
|
||||
w.m2x.sub(d1x * d2x)
|
||||
w.cov.sub(d1y * d2x)
|
||||
}
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
y := k.lo + x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
func (k *kahan) sub(x float64) {
|
||||
y := k.lo - x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
@@ -9,12 +9,14 @@ func Test_welford(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var s1, s2 welford
|
||||
s1.enqueue(1)
|
||||
s1.dequeue(1)
|
||||
|
||||
s1.enqueue(4)
|
||||
s1.enqueue(7)
|
||||
s1.enqueue(13)
|
||||
s1.enqueue(16)
|
||||
if got := s1.average(); got != 10 {
|
||||
if got := s1.mean(); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := s1.var_samp(); got != 30 {
|
||||
@@ -43,6 +45,8 @@ func Test_covar(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var c1, c2 welford2
|
||||
c1.enqueue(1, 1)
|
||||
c1.dequeue(1, 1)
|
||||
|
||||
c1.enqueue(3, 70)
|
||||
c1.enqueue(5, 80)
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
// Package unicode provides an alternative to the SQLite ICU extension.
|
||||
//
|
||||
// Like the [ICU extension], it provides Unicode aware:
|
||||
// - upper() and lower() functions,
|
||||
// - LIKE and REGEXP operators,
|
||||
// - collation sequences.
|
||||
// - upper() and lower() functions
|
||||
// - LIKE and REGEXP operators
|
||||
// - collation sequences
|
||||
//
|
||||
// It also provides, from PostgreSQL:
|
||||
// - unaccent(),
|
||||
// - initcap().
|
||||
// Like PostgreSQL, it also provides:
|
||||
// - initcap()
|
||||
// - casefold()
|
||||
// - normalize()
|
||||
// - unaccent()
|
||||
//
|
||||
// The implementation is not 100% compatible with the [ICU extension]:
|
||||
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
|
||||
// - the LIKE operator follows [strings.EqualFold] rules;
|
||||
// - the REGEXP operator uses Go [regexp/syntax];
|
||||
// - collation sequences use [collate].
|
||||
// The implementations are not 100% compatible:
|
||||
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]
|
||||
// - normalize(), unaccent() use [transform] and [unicode.Mn]
|
||||
// - the LIKE operator follows [strings.EqualFold] rules
|
||||
// - the REGEXP operator uses Go [regexp/syntax]
|
||||
// - collation sequences use [collate]
|
||||
//
|
||||
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
|
||||
//
|
||||
@@ -25,6 +28,7 @@ import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -39,7 +43,7 @@ import (
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Set RegisterLike to false to not register a Unicode aware LIKE operator.
|
||||
// RegisterLike must be set to false to not register a Unicode aware LIKE operator.
|
||||
// Overriding the built-in LIKE operator disables the [LIKE optimization].
|
||||
//
|
||||
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
|
||||
@@ -48,13 +52,13 @@ var RegisterLike = true
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
var errs util.ErrorJoiner
|
||||
var lkfn sqlite3.ScalarFunction
|
||||
if RegisterLike {
|
||||
errs.Join(
|
||||
db.CreateFunction("like", 2, flags, like),
|
||||
db.CreateFunction("like", 3, flags, like))
|
||||
lkfn = like
|
||||
}
|
||||
errs.Join(
|
||||
return errors.Join(
|
||||
db.CreateFunction("like", 2, flags, lkfn),
|
||||
db.CreateFunction("like", 3, flags, lkfn),
|
||||
db.CreateFunction("upper", 1, flags, upper),
|
||||
db.CreateFunction("upper", 2, flags, upper),
|
||||
db.CreateFunction("lower", 1, flags, lower),
|
||||
@@ -62,7 +66,10 @@ func Register(db *sqlite3.Conn) error {
|
||||
db.CreateFunction("regexp", 2, flags, regex),
|
||||
db.CreateFunction("initcap", 1, flags, initcap),
|
||||
db.CreateFunction("initcap", 2, flags, initcap),
|
||||
db.CreateFunction("casefold", 1, flags, casefold),
|
||||
db.CreateFunction("unaccent", 1, flags, unaccent),
|
||||
db.CreateFunction("normalize", 1, flags, normalize),
|
||||
db.CreateFunction("normalize", 2, flags, normalize),
|
||||
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
name := arg[1].Text()
|
||||
@@ -76,7 +83,6 @@ func Register(db *sqlite3.Conn) error {
|
||||
return // notest
|
||||
}
|
||||
}))
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// RegisterCollation registers a Unicode collation sequence for a database connection.
|
||||
@@ -109,9 +115,8 @@ func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
c := cases.Upper(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
cs = cases.Upper(t)
|
||||
ctx.SetAuxData(1, cs)
|
||||
}
|
||||
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
|
||||
}
|
||||
@@ -128,9 +133,8 @@ func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
c := cases.Lower(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
cs = cases.Lower(t)
|
||||
ctx.SetAuxData(1, cs)
|
||||
}
|
||||
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
|
||||
}
|
||||
@@ -147,15 +151,26 @@ func initcap(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
c := cases.Title(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
cs = cases.Title(t)
|
||||
ctx.SetAuxData(1, cs)
|
||||
}
|
||||
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
|
||||
}
|
||||
|
||||
func casefold(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultRawText(cases.Fold().Bytes(arg[0].RawText()))
|
||||
}
|
||||
|
||||
var unaccentPool = sync.Pool{
|
||||
New: func() any {
|
||||
return transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
|
||||
},
|
||||
}
|
||||
|
||||
func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
unaccent := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
|
||||
unaccent := unaccentPool.Get().(transform.Transformer)
|
||||
defer unaccentPool.Put(unaccent)
|
||||
|
||||
res, _, err := transform.Bytes(unaccent, arg[0].RawText())
|
||||
if err != nil {
|
||||
ctx.ResultError(err) // notest
|
||||
@@ -164,16 +179,44 @@ func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
}
|
||||
|
||||
func normalize(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
form := norm.NFC
|
||||
if len(arg) > 1 {
|
||||
switch strings.ToUpper(arg[1].Text()) {
|
||||
case "NFC":
|
||||
//
|
||||
case "NFD":
|
||||
form = norm.NFD
|
||||
case "NFKC":
|
||||
form = norm.NFKC
|
||||
case "NFKD":
|
||||
form = norm.NFKD
|
||||
default:
|
||||
ctx.ResultError(util.ErrorString("unicode: invalid form"))
|
||||
return
|
||||
}
|
||||
}
|
||||
res, _, err := transform.Bytes(form, arg[0].RawText())
|
||||
if err != nil {
|
||||
ctx.ResultError(err) // notest
|
||||
} else {
|
||||
ctx.ResultRawText(res)
|
||||
}
|
||||
}
|
||||
|
||||
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
re, ok = arg[0].Pointer().(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
re = r
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, r)
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawText()))
|
||||
}
|
||||
@@ -189,6 +232,7 @@ func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
_ = arg[1] // bounds check
|
||||
|
||||
type likeData struct {
|
||||
*regexp.Regexp
|
||||
|
||||
@@ -2,7 +2,7 @@ package unicode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -26,11 +26,10 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0)
|
||||
if !stmt.Step() {
|
||||
t.Fatal(stmt.Err())
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return ""
|
||||
return stmt.ColumnText(0)
|
||||
}
|
||||
|
||||
Register(db)
|
||||
@@ -49,6 +48,12 @@ func TestRegister(t *testing.T) {
|
||||
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
|
||||
{`initcap('Kad je hladno Marko nosi džemper')`, "Kad Je Hladno Marko Nosi Džemper"},
|
||||
{`initcap('Kad je hladno Marko nosi džemper', 'hr-HR')`, "Kad Je Hladno Marko Nosi Džemper"},
|
||||
{`normalize(X'61cc88')`, "ä"},
|
||||
{`normalize(X'61cc88', 'NFC' )`, "ä"},
|
||||
{`normalize(X'61cc88', 'NFKC')`, "ä"},
|
||||
{`normalize('ä', 'NFD' )`, "\x61\xcc\x88"},
|
||||
{`normalize('ä', 'NFKD')`, "\x61\xcc\x88"},
|
||||
{`casefold('Maße')`, "masse"},
|
||||
{`unaccent('Hôtel')`, "Hotel"},
|
||||
{`'Hello' REGEXP 'ell'`, "1"},
|
||||
{`'Hello' REGEXP 'el.'`, "1"},
|
||||
@@ -115,7 +120,7 @@ func TestRegister_collation(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
if !slices.Equal(got, want) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
|
||||
@@ -166,7 +171,7 @@ func TestRegisterCollationsNeeded(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
if !slices.Equal(got, want) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
|
||||
@@ -208,6 +213,14 @@ func TestRegister_error(t *testing.T) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT normalize('', 'NF')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -16,17 +17,18 @@ import (
|
||||
|
||||
// Register registers the SQL functions:
|
||||
//
|
||||
// uuid([version], [domain/namespace], [id/data])
|
||||
//
|
||||
// Generates a UUID as a string.
|
||||
//
|
||||
// uuid_str(u)
|
||||
//
|
||||
// Converts a UUID into a well-formed UUID string.
|
||||
//
|
||||
// uuid_blob(u)
|
||||
//
|
||||
// Converts a UUID into a 16-byte blob.
|
||||
// - uuid([ version [, domain/namespace, [ id/data ]]]):
|
||||
// to generate a UUID as a string
|
||||
// - uuid_str(u):
|
||||
// to convert a UUID into a well-formed UUID string
|
||||
// - uuid_blob(u):
|
||||
// to convert a UUID into a 16-byte blob
|
||||
// - uuid_extract_version(u):
|
||||
// to extract the version of a RFC 4122 UUID
|
||||
// - uuid_extract_timestamp(u):
|
||||
// to extract the timestamp of a version 1/2/6/7 UUID
|
||||
// - gen_random_uuid(u):
|
||||
// to generate a version 4 (random) UUID
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
return errors.Join(
|
||||
@@ -35,7 +37,10 @@ func Register(db *sqlite3.Conn) error {
|
||||
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
|
||||
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
|
||||
db.CreateFunction("uuid_str", 1, flags, toString),
|
||||
db.CreateFunction("uuid_blob", 1, flags, toBlob))
|
||||
db.CreateFunction("uuid_blob", 1, flags, toBlob),
|
||||
db.CreateFunction("uuid_extract_version", 1, flags, version),
|
||||
db.CreateFunction("uuid_extract_timestamp", 1, flags, timestamp),
|
||||
db.CreateFunction("gen_random_uuid", 0, sqlite3.INNOCUOUS, generate))
|
||||
}
|
||||
|
||||
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
@@ -167,3 +172,30 @@ func toString(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText(u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func version(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
u, err := fromValue(arg[0])
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
if u.Variant() == uuid.RFC4122 {
|
||||
ctx.ResultInt64(int64(u.Version()))
|
||||
}
|
||||
}
|
||||
|
||||
func timestamp(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
u, err := fromValue(arg[0])
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
if u.Variant() == uuid.RFC4122 {
|
||||
switch u.Version() {
|
||||
case 1, 2, 6, 7:
|
||||
ctx.ResultTime(
|
||||
time.Unix(u.Time().UnixTime()).UTC(),
|
||||
sqlite3.TimeFormatDefault)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package uuid
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -13,9 +14,9 @@ import (
|
||||
|
||||
func Test_generate(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, Register)
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -106,7 +107,26 @@ func Test_generate(t *testing.T) {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
hash := []struct {
|
||||
var tstamp time.Time
|
||||
var version uuid.Version
|
||||
err = db.QueryRow(`
|
||||
SELECT
|
||||
column1,
|
||||
uuid_extract_version(column1),
|
||||
uuid_extract_timestamp(column1)
|
||||
FROM (VALUES (uuid(7)))
|
||||
`).Scan(&u, &version, &tstamp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := u.Version(); got != version {
|
||||
t.Errorf("got %d, want %d", got, version)
|
||||
}
|
||||
if got := time.Unix(u.Time().UnixTime()); !got.Equal(tstamp) {
|
||||
t.Errorf("got %v, want %v", got, tstamp)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ver uuid.Version
|
||||
ns any
|
||||
data string
|
||||
@@ -120,7 +140,7 @@ func Test_generate(t *testing.T) {
|
||||
{3, "url", "https://www.php.net", uuid.MustParse("3f703955-aaba-3e70-a3cb-baff6aa3b28f")},
|
||||
{5, "url", "https://www.php.net", uuid.MustParse("a8f6ae40-d8a7-58f0-be05-a22f94eca9ec")},
|
||||
}
|
||||
for _, tt := range hash {
|
||||
for _, tt := range tests {
|
||||
err = db.QueryRow(`SELECT uuid(?, ?, ?)`, tt.ver, tt.ns, tt.data).Scan(&u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -133,23 +153,23 @@ func Test_generate(t *testing.T) {
|
||||
|
||||
func Test_convert(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, Register)
|
||||
db, err := driver.Open(dsn, Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var u uuid.UUID
|
||||
lits := []string{
|
||||
tests := []string{
|
||||
"'6ba7b8119dad11d180b400c04fd430c8'",
|
||||
"'6ba7b811-9dad-11d1-80b4-00c04fd430c8'",
|
||||
"'{6ba7b811-9dad-11d1-80b4-00c04fd430c8}'",
|
||||
"X'6ba7b8119dad11d180b400c04fd430c8'",
|
||||
}
|
||||
|
||||
for _, tt := range lits {
|
||||
for _, tt := range tests {
|
||||
err = db.QueryRow(`SELECT uuid_str(` + tt + `)`).Scan(&u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -159,7 +179,7 @@ func Test_convert(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range lits {
|
||||
for _, tt := range tests {
|
||||
err = db.QueryRow(`SELECT uuid_blob(` + tt + `)`).Scan(&u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -178,4 +198,14 @@ func Test_convert(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT uuid_extract_version(X'cafe')`).Scan(&u)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT uuid_extract_timestamp(X'cafe')`).Scan(&u)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,9 +19,9 @@ func Register(db *sqlite3.Conn) error {
|
||||
}
|
||||
|
||||
func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
var x [63]int64
|
||||
if len(arg) > len(x) {
|
||||
ctx.ResultError(util.ErrorString("zorder: too many parameters"))
|
||||
var x [24]int64
|
||||
if n := len(arg); n < 2 || n > 24 {
|
||||
ctx.ResultError(util.ErrorString("zorder: needs between 2 and 24 dimensions"))
|
||||
return
|
||||
}
|
||||
for i := range arg {
|
||||
@@ -29,17 +29,15 @@ func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
}
|
||||
|
||||
var z int64
|
||||
if len(arg) > 0 {
|
||||
for i := range x {
|
||||
j := i % len(arg)
|
||||
z |= (x[j] & 1) << i
|
||||
x[j] >>= 1
|
||||
}
|
||||
for i := range 63 {
|
||||
j := i % len(arg)
|
||||
z |= (x[j] & 1) << i
|
||||
x[j] >>= 1
|
||||
}
|
||||
|
||||
for i := range arg {
|
||||
if x[i] != 0 {
|
||||
ctx.ResultError(util.ErrorString("zorder: parameter too large"))
|
||||
ctx.ResultError(util.ErrorString("zorder: argument out of range"))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -51,6 +49,19 @@ func unzorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
n := arg[1].Int64()
|
||||
z := arg[0].Int64()
|
||||
|
||||
if n < 2 || n > 24 {
|
||||
ctx.ResultError(util.ErrorString("unzorder: needs between 2 and 24 dimensions"))
|
||||
return
|
||||
}
|
||||
if i < 0 || i >= n {
|
||||
ctx.ResultError(util.ErrorString("unzorder: index out of range"))
|
||||
return
|
||||
}
|
||||
if z < 0 {
|
||||
ctx.ResultError(util.ErrorString("unzorder: argument out of range"))
|
||||
return
|
||||
}
|
||||
|
||||
var k int
|
||||
var x int64
|
||||
for j := i; j < 63; j += n {
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func TestRegister_zorder(t *testing.T) {
|
||||
func Test_zorder(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, zorder.Register)
|
||||
db, err := driver.Open(dsn, zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -57,11 +57,11 @@ func TestRegister_zorder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_unzorder(t *testing.T) {
|
||||
func Test_unzorder(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, zorder.Register)
|
||||
db, err := driver.Open(dsn, zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -85,11 +85,11 @@ func TestRegister_unzorder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_error(t *testing.T) {
|
||||
func Test_zorder_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, zorder.Register)
|
||||
db, err := driver.Open(dsn, zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func TestRegister_error(t *testing.T) {
|
||||
|
||||
var buf strings.Builder
|
||||
buf.WriteString("SELECT zorder(0")
|
||||
for i := 1; i < 80; i++ {
|
||||
for i := 1; i < 25; i++ {
|
||||
buf.WriteByte(',')
|
||||
buf.WriteString(strconv.Itoa(0))
|
||||
}
|
||||
@@ -113,3 +113,30 @@ func TestRegister_error(t *testing.T) {
|
||||
t.Error("want error")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unzorder_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
dsn := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(dsn, zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var got int64
|
||||
err = db.QueryRow(`SELECT unzorder(-1, 2, 0)`).Scan(&got)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT unzorder(0, 2, 2)`).Scan(&got)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT unzorder(0, 25, 2)`).Scan(&got)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
}
|
||||
|
||||
233
func.go
233
func.go
@@ -2,7 +2,10 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"iter"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
|
||||
@@ -14,12 +17,12 @@ import (
|
||||
//
|
||||
// https://sqlite.org/c3ref/collation_needed.html
|
||||
func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
|
||||
var enable uint64
|
||||
var enable int32
|
||||
if cb != nil {
|
||||
enable = 1
|
||||
}
|
||||
r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable)
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_collation_needed_go", stk_t(c.handle), stk_t(enable)))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
c.collation = cb
|
||||
@@ -33,8 +36,8 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
|
||||
// This can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c Conn) AnyCollationNeeded() error {
|
||||
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
|
||||
if err := c.error(r); err != nil {
|
||||
rc := res_t(c.call("sqlite3_anycollseq_init", stk_t(c.handle), 0, 0))
|
||||
if err := c.error(rc); err != nil {
|
||||
return err
|
||||
}
|
||||
c.collation = nil
|
||||
@@ -44,60 +47,103 @@ func (c Conn) AnyCollationNeeded() error {
|
||||
// CreateCollation defines a new collating sequence.
|
||||
//
|
||||
// https://sqlite.org/c3ref/create_collation.html
|
||||
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
||||
var funcPtr uint32
|
||||
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
|
||||
var funcPtr ptr_t
|
||||
defer c.arena.mark()()
|
||||
namePtr := c.arena.string(name)
|
||||
if fn != nil {
|
||||
funcPtr = util.AddHandle(c.ctx, fn)
|
||||
}
|
||||
r := c.call("sqlite3_create_collation_go",
|
||||
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_create_collation_go",
|
||||
stk_t(c.handle), stk_t(namePtr), stk_t(funcPtr)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
// CollatingFunction is the type of a collation callback.
|
||||
// Implementations must not retain a or b.
|
||||
type CollatingFunction func(a, b []byte) int
|
||||
|
||||
// CreateFunction defines a new scalar SQL function.
|
||||
//
|
||||
// https://sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
|
||||
var funcPtr uint32
|
||||
var funcPtr ptr_t
|
||||
defer c.arena.mark()()
|
||||
namePtr := c.arena.string(name)
|
||||
if fn != nil {
|
||||
funcPtr = util.AddHandle(c.ctx, fn)
|
||||
}
|
||||
r := c.call("sqlite3_create_function_go",
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_create_function_go",
|
||||
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
|
||||
stk_t(flag), stk_t(funcPtr)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
// ScalarFunction is the type of a scalar SQL function.
|
||||
// Implementations must not retain arg.
|
||||
type ScalarFunction func(ctx Context, arg ...Value)
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
// If fn returns an [io.Closer], it will be called to free resources.
|
||||
// CreateAggregateFunction defines a new aggregate SQL function.
|
||||
//
|
||||
// https://sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
var funcPtr uint32
|
||||
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
|
||||
var funcPtr ptr_t
|
||||
defer c.arena.mark()()
|
||||
namePtr := c.arena.string(name)
|
||||
if fn != nil {
|
||||
funcPtr = util.AddHandle(c.ctx, fn)
|
||||
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
|
||||
var a aggregateFunc
|
||||
coro := func(yieldCoro func(struct{}) bool) {
|
||||
seq := func(yieldSeq func([]Value) bool) {
|
||||
for yieldSeq(a.arg) {
|
||||
if !yieldCoro(struct{}{}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
fn(&a.ctx, seq)
|
||||
}
|
||||
a.next, a.stop = iter.Pull(coro)
|
||||
return &a
|
||||
}))
|
||||
}
|
||||
call := "sqlite3_create_aggregate_function_go"
|
||||
if _, ok := fn().(WindowFunction); ok {
|
||||
call = "sqlite3_create_window_function_go"
|
||||
}
|
||||
r := c.call(call,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
|
||||
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
|
||||
stk_t(flag), stk_t(funcPtr)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
// AggregateSeqFunction is the type of an aggregate SQL function.
|
||||
// Implementations must not retain the slices yielded by seq.
|
||||
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], an aggregate window function is created.
|
||||
// If fn returns an [io.Closer], it will be called to free resources.
|
||||
//
|
||||
// https://sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
|
||||
var funcPtr ptr_t
|
||||
defer c.arena.mark()()
|
||||
namePtr := c.arena.string(name)
|
||||
if fn != nil {
|
||||
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
|
||||
agg := fn()
|
||||
if win, ok := agg.(WindowFunction); ok {
|
||||
return win
|
||||
}
|
||||
return agg
|
||||
}))
|
||||
}
|
||||
rc := res_t(c.call("sqlite3_create_window_function_go",
|
||||
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
|
||||
stk_t(flag), stk_t(funcPtr)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
// AggregateConstructor is a an [AggregateFunction] constructor.
|
||||
type AggregateConstructor func() AggregateFunction
|
||||
|
||||
// AggregateFunction is the interface an aggregate function should implement.
|
||||
//
|
||||
// https://sqlite.org/appfunc.html
|
||||
@@ -129,102 +175,135 @@ type WindowFunction interface {
|
||||
func (c *Conn) OverloadFunction(name string, nArg int) error {
|
||||
defer c.arena.mark()()
|
||||
namePtr := c.arena.string(name)
|
||||
r := c.call("sqlite3_overload_function",
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg))
|
||||
return c.error(r)
|
||||
rc := res_t(c.call("sqlite3_overload_function",
|
||||
stk_t(c.handle), stk_t(namePtr), stk_t(nArg)))
|
||||
return c.error(rc)
|
||||
}
|
||||
|
||||
func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
func destroyCallback(ctx context.Context, mod api.Module, pApp ptr_t) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
|
||||
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {
|
||||
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTextRep uint32, zName ptr_t) {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil {
|
||||
name := util.ReadString(mod, zName, _MAX_NAME)
|
||||
c.collation(c, name)
|
||||
}
|
||||
}
|
||||
|
||||
func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
||||
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
|
||||
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
|
||||
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
|
||||
}
|
||||
|
||||
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) {
|
||||
args := getFuncArgs()
|
||||
defer putFuncArgs(args)
|
||||
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
args := callbackArgs(db, nArg, pArg)
|
||||
defer returnArgs(args)
|
||||
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
|
||||
callbackArgs(db, args[:nArg], pArg)
|
||||
fn(Context{db, pCtx}, args[:nArg]...)
|
||||
fn(Context{db, pCtx}, *args...)
|
||||
}
|
||||
|
||||
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) {
|
||||
args := getFuncArgs()
|
||||
defer putFuncArgs(args)
|
||||
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
callbackArgs(db, args[:nArg], pArg)
|
||||
args := callbackArgs(db, nArg, pArg)
|
||||
defer returnArgs(args)
|
||||
fn, _ := callbackAggregate(db, pAgg, pApp)
|
||||
fn.Step(Context{db, pCtx}, args[:nArg]...)
|
||||
fn.Step(Context{db, pCtx}, *args...)
|
||||
}
|
||||
|
||||
func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) {
|
||||
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn, handle := callbackAggregate(db, pAgg, pApp)
|
||||
fn.Value(Context{db, pCtx})
|
||||
util.DelHandle(ctx, handle)
|
||||
|
||||
// Cleanup.
|
||||
if final != 0 {
|
||||
var err error
|
||||
if handle != 0 {
|
||||
err = util.DelHandle(ctx, handle)
|
||||
} else if c, ok := fn.(io.Closer); ok {
|
||||
err = c.Close()
|
||||
}
|
||||
if err != nil {
|
||||
Context{db, pCtx}.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) {
|
||||
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction)
|
||||
fn.Value(Context{db, pCtx})
|
||||
}
|
||||
|
||||
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) {
|
||||
args := getFuncArgs()
|
||||
defer putFuncArgs(args)
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
callbackArgs(db, args[:nArg], pArg)
|
||||
args := callbackArgs(db, nArg, pArg)
|
||||
defer returnArgs(args)
|
||||
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
|
||||
fn.Inverse(Context{db, pCtx}, args[:nArg]...)
|
||||
fn.Inverse(Context{db, pCtx}, *args...)
|
||||
}
|
||||
|
||||
func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) {
|
||||
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
|
||||
if pApp == 0 {
|
||||
handle := util.ReadUint32(db.mod, pAgg)
|
||||
handle := util.Read32[ptr_t](db.mod, pAgg)
|
||||
return util.GetHandle(db.ctx, handle).(AggregateFunction), handle
|
||||
}
|
||||
|
||||
// We need to create the aggregate.
|
||||
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
|
||||
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
|
||||
if pAgg != 0 {
|
||||
handle := util.AddHandle(db.ctx, fn)
|
||||
util.WriteUint32(db.mod, pAgg, handle)
|
||||
util.Write32(db.mod, pAgg, handle)
|
||||
return fn, handle
|
||||
}
|
||||
return fn, 0
|
||||
}
|
||||
|
||||
func callbackArgs(db *Conn, arg []Value, pArg uint32) {
|
||||
for i := range arg {
|
||||
arg[i] = Value{
|
||||
var (
|
||||
valueArgsPool sync.Pool
|
||||
valueArgsLen atomic.Int32
|
||||
)
|
||||
|
||||
func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value {
|
||||
arg, ok := valueArgsPool.Get().(*[]Value)
|
||||
if !ok || cap(*arg) < int(nArg) {
|
||||
max := valueArgsLen.Or(nArg) | nArg
|
||||
lst := make([]Value, max)
|
||||
arg = &lst
|
||||
}
|
||||
lst := (*arg)[:nArg]
|
||||
for i := range lst {
|
||||
lst[i] = Value{
|
||||
c: db,
|
||||
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
|
||||
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen),
|
||||
}
|
||||
}
|
||||
*arg = lst
|
||||
return arg
|
||||
}
|
||||
|
||||
var funcArgsPool sync.Pool
|
||||
|
||||
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
|
||||
funcArgsPool.Put(p)
|
||||
func returnArgs(p *[]Value) {
|
||||
valueArgsPool.Put(p)
|
||||
}
|
||||
|
||||
func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
|
||||
if p := funcArgsPool.Get(); p == nil {
|
||||
return new([_MAX_FUNCTION_ARG]Value)
|
||||
} else {
|
||||
return p.(*[_MAX_FUNCTION_ARG]Value)
|
||||
type aggregateFunc struct {
|
||||
next func() (struct{}, bool)
|
||||
stop func()
|
||||
ctx Context
|
||||
arg []Value
|
||||
}
|
||||
|
||||
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
|
||||
a.ctx = ctx
|
||||
a.arg = append(a.arg[:0], arg...)
|
||||
if _, more := a.next(); !more {
|
||||
a.stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (a *aggregateFunc) Value(ctx Context) {
|
||||
a.ctx = ctx
|
||||
a.stop()
|
||||
}
|
||||
|
||||
func (a *aggregateFunc) Close() error {
|
||||
a.stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
57
func_seq_test.go
Normal file
57
func_seq_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"log"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateAggregateFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS,
|
||||
func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) {
|
||||
count := 0
|
||||
total := 0.0
|
||||
for arg := range seq {
|
||||
total += arg[0].Float()
|
||||
count++
|
||||
}
|
||||
ctx.ResultFloat(total / float64(count))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnFloat(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// 2
|
||||
}
|
||||
22
go.mod
22
go.mod
@@ -1,24 +1,26 @@
|
||||
module github.com/ncruces/go-sqlite3
|
||||
|
||||
go 1.21
|
||||
|
||||
toolchain go1.23.0
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v1.0.0
|
||||
github.com/ncruces/sort v0.1.2
|
||||
github.com/tetratelabs/wazero v1.8.2
|
||||
golang.org/x/crypto v0.31.0
|
||||
golang.org/x/sys v0.28.0
|
||||
github.com/ncruces/sort v0.1.6
|
||||
github.com/ncruces/wbt v1.0.0
|
||||
github.com/tetratelabs/wazero v1.11.0
|
||||
golang.org/x/sys v0.40.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dchest/siphash v1.2.3 // ext/bloom
|
||||
github.com/google/uuid v1.6.0 // ext/uuid
|
||||
github.com/psanford/httpreadat v0.1.0 // example
|
||||
golang.org/x/sync v0.10.0 // test
|
||||
golang.org/x/text v0.21.0 // ext/unicode
|
||||
golang.org/x/crypto v0.46.0 // vfs/adiantum vfs/xts
|
||||
golang.org/x/sync v0.19.0 // test
|
||||
golang.org/x/text v0.32.0 // ext/unicode
|
||||
lukechampine.com/adiantum v1.1.1 // vfs/adiantum
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
retract (
|
||||
v0.23.2 // tagged from the wrong branch
|
||||
v0.4.0 // tagged from the wrong branch
|
||||
)
|
||||
|
||||
26
go.sum
26
go.sum
@@ -4,19 +4,21 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
|
||||
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=
|
||||
github.com/ncruces/sort v0.1.2/go.mod h1:vEJUTBJtebIuCMmXD18GKo5GJGhsay+xZFOoBEIXFmE=
|
||||
github.com/ncruces/sort v0.1.6 h1:TrsJfGRH1AoWoaeB4/+gCohot9+cA6u/INaH5agIhNk=
|
||||
github.com/ncruces/sort v0.1.6/go.mod h1:obJToO4rYr6VWP0Uw5FYymgYGt3Br4RXcs/JdKaXAPk=
|
||||
github.com/ncruces/wbt v1.0.0 h1:8iBE7UPjTLUpzu3/FCRjAmuQjWzgxo10RGBgt3ooLSc=
|
||||
github.com/ncruces/wbt v1.0.0/go.mod h1:DtF92amvMxH69EmBFUSFWRDAlo6hOEfoNQnClxj9C/c=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
|
||||
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
|
||||
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
|
||||
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
lukechampine.com/adiantum v1.1.1 h1:4fp6gTxWCqpEbLy40ExiYDDED3oUNWx5cTqBCtPdZqA=
|
||||
lukechampine.com/adiantum v1.1.1/go.mod h1:LrAYVnTYLnUtE/yMp5bQr0HstAf060YUF8nM0B6+rUw=
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user