Compare commits

..

1 Commits

Author SHA1 Message Date
Nuno Cruces
fc3a993c3e Parquet vtab. 2025-01-07 16:33:01 +00:00
245 changed files with 3637 additions and 9915 deletions

View File

@@ -1,6 +1,11 @@
name: VM Actions matrix
description: VM Actions matrix template
inputs:
run:
description: The CI command to run
required: true
runs:
using: composite
steps:
@@ -8,4 +13,4 @@ runs:
with:
usesh: true
copyback: false
run: . ./test.sh
run: ${{inputs.run}}

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
echo 'set -eux' > test.sh
echo 'set -eu' > test.sh
for p in $(go list ./...); do
dir=".${p#github.com/ncruces/go-sqlite3}"

View File

@@ -1,23 +0,0 @@
name: Benchmark libc
on:
workflow_dispatch:
permissions:
contents: read
jobs:
test:
strategy:
matrix:
os: [ubuntu-24.04, ubuntu-24.04-arm, macos-15, macos-15-intel]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with: { go-version: stable }
- name: Benchmark
shell: bash
run: sqlite3/libc/benchmark.sh

View File

@@ -1,9 +1,28 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ "$OSTYPE" == "linux"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-linux.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_121/binaryen-version_121-x86_64-linux.tar.gz"
elif [[ "$OSTYPE" == "darwin"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-arm64-macos.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_121/binaryen-version_121-arm64-macos.tar.gz"
elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-windows.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_121/binaryen-version_121-x86_64-windows.tar.gz"
fi
# Download tools
mkdir -p tools/
[ -d "tools/wasi-sdk" ] || curl -#L "$WASI_SDK" | tar xzC tools &
[ -d "tools/binaryen" ] || curl -#L "$BINARYEN" | tar xzC tools &
wait
[ -d "tools/wasi-sdk" ] || mv "tools/wasi-sdk"* "tools/wasi-sdk"
[ -d "tools/binaryen" ] || mv "tools/binaryen"* "tools/binaryen"
# Download and build SQLite
sqlite3/download.sh
sqlite3/tools.sh
embed/build.sh
embed/bcw2/build.sh

View File

@@ -17,13 +17,13 @@ jobs:
steps:
- uses: ilammy/msvc-dev-cmd@v1
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Build
shell: bash
run: .github/workflows/repro.sh
- uses: actions/attest-build-provenance@v3
- uses: actions/attest-build-provenance@v2
if: matrix.os == 'ubuntu-latest'
with:
subject-path: |

View File

@@ -7,31 +7,24 @@ on:
- '**.go'
- '**.mod'
- '**.wasm'
- '**.yml'
pull_request:
branches: [ 'main' ]
paths:
- '**.go'
- '**.mod'
- '**.wasm'
- '**.yml'
workflow_dispatch:
permissions:
contents: read
jobs:
test:
strategy:
matrix:
os: [macos-latest, ubuntu-latest, windows-latest]
runs-on: ${{ matrix.os }}
permissions:
contents: write
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with: { go-version: stable }
- name: Format
@@ -51,35 +44,29 @@ jobs:
run: go vet ./...
- name: Build
run: go build ./...
run: go build -v ./...
- name: Test
run: go test ./... -bench . -benchtime=1x
run: go test -v ./... -bench . -benchtime=1x
- name: Test BSD locks
run: go test -tags sqlite3_flock ./...
run: go test -v -tags sqlite3_flock ./...
if: matrix.os != 'windows-latest'
- name: Test dot locks
run: go test -tags sqlite3_dotlk ./...
run: go test -v -tags sqlite3_dotlk ./...
if: matrix.os != 'windows-latest'
- name: Test modules
shell: bash
run: |
go work init .
go work use -r embed/bcw2 gormlite
go test ./embed/bcw2 ./gormlite
- name: Test GORM
shell: bash
run: gormlite/test.sh
if: matrix.os == 'ubuntu-latest'
- name: Test modules
shell: bash
run: go test -v ./embed/bcw2/...
- name: Collect coverage
run: |
go get -tool github.com/dave/courtney@v0.4.4
go tool courtney
run: go run github.com/dave/courtney@latest
if: |
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'
@@ -93,45 +80,44 @@ jobs:
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'
test-cross:
test-bsd:
strategy:
matrix:
os:
- name: freebsd
version: '15.0'
version: '14.2'
flags: '-test.v'
- name: netbsd
version: '10.1'
- name: illumos
action: omnios
version: 'r151056'
- name: openbsd
version: '7.8'
tflags: '-test.short'
version: '10.0'
flags: '-test.v'
- name: freebsd
arch: arm64
version: '15.0'
tflags: '-test.short'
version: '14.2'
flags: '-test.v -test.short'
- name: netbsd
arch: arm64
version: '10.1'
tflags: '-test.short'
version: '10.0'
flags: '-test.v -test.short'
- name: openbsd
version: '7.6'
flags: '-test.v -test.short'
runs-on: ubuntu-latest
needs: test
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Build
env:
GOOS: ${{ matrix.os.name }}
GOARCH: ${{ matrix.os.arch }}
TESTFLAGS: ${{ matrix.os.tflags }}
TESTFLAGS: ${{ matrix.os.flags }}
run: .github/workflows/build-test.sh
- name: Test
uses: cross-platform-actions/action@v0.32.0
uses: cross-platform-actions/action@v0.26.0
with:
operating_system: ${{ matrix.os.action || matrix.os.name }}
operating_system: ${{ matrix.os.name }}
architecture: ${{ matrix.os.arch }}
version: ${{ matrix.os.version }}
shell: bash
@@ -144,16 +130,19 @@ jobs:
os:
- name: dragonfly
action: 'vmactions/dragonflybsd-vm@v1'
tflags: '-test.v'
- name: illumos
action: 'vmactions/openindiana-vm@v0'
action: 'vmactions/omnios-vm@v1'
tflags: '-test.v'
- name: solaris
action: 'vmactions/solaris-vm@v1'
bflags: '-tags sqlite3_dotlk'
tflags: '-test.v'
runs-on: ubuntu-latest
needs: test
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Build
env:
@@ -165,6 +154,10 @@ jobs:
- name: Test
uses: ./.github/actions/vmactions
with:
usesh: true
copyback: false
run: . ./test.sh
test-wasip1:
runs-on: ubuntu-latest
@@ -172,12 +165,12 @@ jobs:
steps:
- uses: bytecodealliance/actions/wasmtime/setup@v1
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with: { go-version: stable }
- name: Set path
run: echo "$(go env GOROOT)/lib/wasm" >> "$GITHUB_PATH"
run: echo "$(go env GOROOT)/misc/wasm" >> "$GITHUB_PATH"
- name: Test wasmtime
env:
@@ -185,7 +178,7 @@ jobs:
GOARCH: wasm
GOWASIRUNTIME: wasmtime
GOWASIRUNTIMEARGS: '--env CI=true'
run: go test -short -tags sqlite3_dotlk -skip Example ./...
run: go test -v -short -tags sqlite3_dotlk -skip Example ./...
test-qemu:
runs-on: ubuntu-latest
@@ -193,60 +186,33 @@ jobs:
steps:
- uses: docker/setup-qemu-action@v3
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with: { go-version: stable }
- name: Test 386 (32-bit)
run: GOARCH=386 go test -short ./...
run: GOARCH=386 go test -v -short ./...
- name: Test arm64 (compiler)
run: GOARCH=arm64 go test -v -short ./...
- name: Test riscv64 (interpreter)
run: GOARCH=riscv64 go test -short ./...
run: GOARCH=riscv64 go test -v -short ./...
- name: Test ppc64le (interpreter)
run: GOARCH=ppc64le go test -short ./...
- name: Test loong64 (interpreter)
run: GOARCH=loong64 go test -short ./...
run: GOARCH=ppc64le go test -v -short ./...
- name: Test s390x (big-endian)
run: GOARCH=s390x go test -short -tags sqlite3_dotlk ./...
test-linuxarm:
runs-on: ubuntu-24.04-arm
needs: test
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with: { go-version: stable }
- name: Test
run: go test ./...
- name: Test arm (32-bit)
run: GOARCH=arm GOARM=7 go test -short ./...
run: GOARCH=s390x go test -v -short -tags sqlite3_dotlk ./...
test-macintel:
runs-on: macos-15-intel
runs-on: macos-13
needs: test
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with: { go-version: stable }
- name: Test
run: go test ./...
test-winarm:
runs-on: windows-11-arm
needs: test
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with: { go-version: stable }
- name: Test
run: go test ./...
run: go test -v ./...

9
.gitignore vendored
View File

@@ -13,11 +13,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
tools
# Go workspace file
go.work
go.work.sum
# env file
.env
tools

View File

@@ -30,10 +30,10 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://sqlite.org/cintro.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
@@ -44,19 +44,12 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
### Advanced features
- [incremental BLOB I/O](https://sqlite.org/c3ref/blob_open.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blobio#example-package))
- [nested transactions](https://sqlite.org/lang_savepoint.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-Savepoint))
- [custom functions](https://sqlite.org/c3ref/create_function.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-Conn.CreateFunction))
- [virtual tables](https://sqlite.org/vtab.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-CreateModule))
- [custom VFSes](https://sqlite.org/vfs.html)
([examples](vfs/README.md#custom-vfses))
- [online backup](https://sqlite.org/backup.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#Conn))
- [JSON support](https://sqlite.org/json1.html)
([example](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package-Json))
- [math functions](https://sqlite.org/lang_mathfunc.html)
- [full-text search](https://sqlite.org/fts5.html)
- [geospatial search](https://sqlite.org/geopoly.html)
@@ -64,6 +57,7 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
- [statistics functions](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
- [encryption at rest](vfs/adiantum/README.md)
- [many extensions](ext/README.md)
- [custom VFSes](vfs/README.md#custom-vfses)
- [and more…](embed/README.md)
### Caveats
@@ -71,52 +65,31 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
This module replaces the SQLite [OS Interface](https://sqlite.org/vfs.html)
(aka VFS) with a [pure Go](vfs/) implementation,
which has advantages and disadvantages.
Read more about the Go VFS design [here](vfs/README.md).
Because each database connection executes within a Wasm sandboxed environment,
memory usage will be higher than alternatives.
Read more about the Go VFS design [here](vfs/README.md).
### Testing
This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report).
It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach)
thorough testing.
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing.
Every commit is tested on:
* Linux: amd64, arm64, 386, arm, riscv64, ppc64le, loong64, s390x
* macOS: amd64, arm64
* Windows: amd64, arm64
* BSD:
* FreeBSD: amd64, arm64
* NetBSD: amd64, arm64
* DragonFly BSD: amd64
* OpenBSD: amd64
* illumos: amd64
* Solaris: amd64
Certain operating system and CPU combinations have some limitations. See the [support matrix](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) for a complete overview.
Every commit is [tested](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) on
Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (amd64/arm64),
Windows (amd64), FreeBSD (amd64), OpenBSD (amd64), NetBSD (amd64),
DragonFly BSD (amd64), illumos (amd64), and Solaris (amd64).
The Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c).
### Performance
Performance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
Perfomance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
[competitive](https://github.com/cvilsmeier/go-sqlite-bench) with alternatives.
The Wasm and VFS layers are also benchmarked by running SQLite's
The Wasm and VFS layers are also tested by running SQLite's
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Concurrency
This module behaves similarly to SQLite in [multi-thread](https://sqlite.org/threadsafe.html) mode:
it is goroutine-safe, provided that no single database connection, or object derived from it,
is used concurrently by multiple goroutines.
The [`database/sql`](https://pkg.go.dev/database/sql) API is safe to use concurrently,
according to its documentation.
### FAQ, issues, new features
For questions, please see [Discussions](https://github.com/ncruces/go-sqlite3/discussions/categories/q-a).
@@ -125,7 +98,7 @@ Also, post there if you used this driver for something interesting
([_"Show and tell"_](https://github.com/ncruces/go-sqlite3/discussions/categories/show-and-tell)),
have an [idea](https://github.com/ncruces/go-sqlite3/discussions/categories/ideas)…
The [Issue](https://github.com/ncruces/go-sqlite3/issues) tracker is for bugs,
The [Issue](https://github.com/ncruces/go-sqlite3/issues) tracker is for bugs we want fixed,
and features we're working on, planning to work on, or asking for help with.
### Alternatives

View File

@@ -5,8 +5,8 @@ package sqlite3
// https://sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle ptr_t
otherc ptr_t
handle uint32
otherc uint32
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
@@ -61,7 +61,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) (*Backup, error) {
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
defer c.arena.mark()()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
@@ -71,19 +71,19 @@ func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string)
other = src
}
ptr := ptr_t(c.call("sqlite3_backup_init",
stk_t(dst), stk_t(dstPtr),
stk_t(src), stk_t(srcPtr)))
if ptr == 0 {
r := c.call("sqlite3_backup_init",
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r == 0 {
defer c.closeDB(other)
rc := res_t(c.call("sqlite3_errcode", stk_t(dst)))
return nil, c.sqlite.error(rc, dst)
r = c.call("sqlite3_errcode", uint64(dst))
return nil, c.sqlite.error(r, dst)
}
return &Backup{
c: c,
otherc: other,
handle: ptr,
handle: uint32(r),
}, nil
}
@@ -97,10 +97,10 @@ func (b *Backup) Close() error {
return nil
}
rc := res_t(b.c.call("sqlite3_backup_finish", stk_t(b.handle)))
r := b.c.call("sqlite3_backup_finish", uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(rc)
return b.c.error(r)
}
// Step copies up to nPage pages between the source and destination databases.
@@ -108,11 +108,11 @@ func (b *Backup) Close() error {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
rc := res_t(b.c.call("sqlite3_backup_step", stk_t(b.handle), stk_t(nPage)))
if rc == _DONE {
r := b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage))
if r == _DONE {
return true, nil
}
return false, b.c.error(rc)
return false, b.c.error(r)
}
// Remaining returns the number of pages still to be backed up
@@ -120,8 +120,8 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
n := int32(b.c.call("sqlite3_backup_remaining", stk_t(b.handle)))
return int(n)
r := b.c.call("sqlite3_backup_remaining", uint64(b.handle))
return int(int32(r))
}
// PageCount returns the total number of pages in the source database
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
n := int32(b.c.call("sqlite3_backup_pagecount", stk_t(b.handle)))
return int(n)
r := b.c.call("sqlite3_backup_pagecount", uint64(b.handle))
return int(int32(r))
}

73
blob.go
View File

@@ -20,8 +20,8 @@ type Blob struct {
c *Conn
bytes int64
offset int64
handle ptr_t
bufptr ptr_t
handle uint32
bufptr uint32
buflen int64
}
@@ -31,32 +31,29 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
if c.interrupt.Err() != nil {
return nil, INTERRUPT
}
defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
tablePtr := c.arena.string(table)
columnPtr := c.arena.string(column)
var flags int32
var flags uint64
if write {
flags = 1
}
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
c.checkInterrupt(c.handle)
r := c.call("sqlite3_blob_open", uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
if err := c.error(rc); err != nil {
if err := c.error(r); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = util.Read32[ptr_t](c.mod, blobPtr)
blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", stk_t(blob.handle))))
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call("sqlite3_blob_bytes", uint64(blob.handle)))
return &blob, nil
}
@@ -70,10 +67,10 @@ func (b *Blob) Close() error {
return nil
}
rc := res_t(b.c.call("sqlite3_blob_close", stk_t(b.handle)))
r := b.c.call("sqlite3_blob_close", uint64(b.handle))
b.c.free(b.bufptr)
b.handle = 0
return b.c.error(rc)
return b.c.error(r)
}
// Size returns the size of the BLOB in bytes.
@@ -97,13 +94,13 @@ func (b *Blob) Read(p []byte) (n int, err error) {
want = avail
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.buflen = want
}
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
uint64(b.bufptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
return 0, err
}
@@ -112,7 +109,7 @@ func (b *Blob) Read(p []byte) (n int, err error) {
err = io.EOF
}
copy(p, util.View(b.c.mod, b.bufptr, want))
copy(p, util.View(b.c.mod, b.bufptr, uint64(want)))
return int(want), err
}
@@ -130,19 +127,19 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
want = avail
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.buflen = want
}
for want > 0 {
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
r := b.c.call("sqlite3_blob_read", uint64(b.handle),
uint64(b.bufptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
return n, err
}
mem := util.View(b.c.mod, b.bufptr, want)
mem := util.View(b.c.mod, b.bufptr, uint64(want))
m, err := w.Write(mem[:want])
b.offset += int64(m)
n += int64(m)
@@ -168,14 +165,14 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
func (b *Blob) Write(p []byte) (n int, err error) {
want := int64(len(p))
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.buflen = want
}
util.WriteBytes(b.c.mod, b.bufptr, p)
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
uint64(b.bufptr), uint64(want), uint64(b.offset))
err = b.c.error(r)
if err != nil {
return 0, err
}
@@ -199,17 +196,17 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want = 1
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.bufptr = b.c.realloc(b.bufptr, uint64(want))
b.buflen = want
}
for {
mem := util.View(b.c.mod, b.bufptr, want)
mem := util.View(b.c.mod, b.bufptr, uint64(want))
m, err := r.Read(mem[:want])
if m > 0 {
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
stk_t(b.bufptr), stk_t(m), stk_t(b.offset)))
err := b.c.error(rc)
r := b.c.call("sqlite3_blob_write", uint64(b.handle),
uint64(b.bufptr), uint64(m), uint64(b.offset))
err := b.c.error(r)
if err != nil {
return n, err
}
@@ -256,11 +253,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
if b.c.interrupt.Err() != nil {
return INTERRUPT
}
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.c.checkInterrupt(b.c.handle)
err := b.c.error(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row)))
b.bytes = int64(b.c.call("sqlite3_blob_bytes", uint64(b.handle)))
b.offset = 0
return err
}

184
config.go
View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strconv"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
@@ -33,7 +32,7 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
defer c.arena.mark()()
argsPtr := c.arena.new(intlen + ptrlen)
var flag int32
var flag int
switch {
case len(arg) == 0:
flag = -1
@@ -41,40 +40,31 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
flag = 1
}
util.Write32(c.mod, argsPtr+0*ptrlen, flag)
util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr)
util.WriteUint32(c.mod, argsPtr+0*ptrlen, uint32(flag))
util.WriteUint32(c.mod, argsPtr+1*ptrlen, argsPtr)
rc := res_t(c.call("sqlite3_db_config", stk_t(c.handle),
stk_t(op), stk_t(argsPtr)))
return util.ReadBool(c.mod, argsPtr), c.error(rc)
}
var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)]
// ConfigLog sets up the default error logging callback for new connections.
//
// https://sqlite.org/errlog.html
func ConfigLog(cb func(code ExtendedErrorCode, msg string)) {
defaultLogger.Store(&cb)
r := c.call("sqlite3_db_config", uint64(c.handle),
uint64(op), uint64(argsPtr))
return util.ReadUint32(c.mod, argsPtr) != 0, c.error(r)
}
// ConfigLog sets up the error logging callback for the connection.
//
// https://sqlite.org/errlog.html
func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error {
var enable int32
var enable uint64
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_config_log_go", stk_t(enable)))
if err := c.error(rc); err != nil {
r := c.call("sqlite3_config_log_go", enable)
if err := c.error(r); err != nil {
return err
}
c.log = cb
return nil
}
func logCallback(ctx context.Context, mod api.Module, _ ptr_t, iCode res_t, zMsg ptr_t) {
func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil {
msg := util.ReadString(mod, zMsg, _MAX_LENGTH)
c.log(xErrorCode(iCode), msg)
@@ -98,97 +88,93 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro
defer c.arena.mark()()
ptr := c.arena.new(max(ptrlen, intlen))
var schemaPtr ptr_t
var schemaPtr uint32
if schema != "" {
schemaPtr = c.arena.string(schema)
}
var rc res_t
var ret any
var rc uint64
var res any
switch op {
default:
return nil, MISUSE
case FCNTL_RESET_CACHE, FCNTL_NULL_IO:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), 0))
case FCNTL_RESET_CACHE:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), 0)
case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE:
var flag int32
var flag int
switch {
case len(arg) == 0:
flag = -1
case arg[0]:
flag = 1
}
util.Write32(c.mod, ptr, flag)
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.ReadBool(c.mod, ptr)
util.WriteUint32(c.mod, ptr, uint32(flag))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = util.ReadUint32(c.mod, ptr) != 0
case FCNTL_CHUNK_SIZE:
util.Write32(c.mod, ptr, int32(arg[0].(int)))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
util.WriteUint32(c.mod, ptr, uint32(arg[0].(int)))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
case FCNTL_RESERVE_BYTES:
bytes := -1
if len(arg) > 0 {
bytes = arg[0].(int)
}
util.Write32(c.mod, ptr, int32(bytes))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = int(util.Read32[int32](c.mod, ptr))
util.WriteUint32(c.mod, ptr, uint32(bytes))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = int(util.ReadUint32(c.mod, ptr))
case FCNTL_DATA_VERSION:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.Read32[uint32](c.mod, ptr)
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = util.ReadUint32(c.mod, ptr)
case FCNTL_LOCKSTATE:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.Read32[vfs.LockLevel](c.mod, ptr)
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
res = vfs.LockLevel(util.ReadUint32(c.mod, ptr))
case FCNTL_VFSNAME, FCNTL_VFS_POINTER:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(FCNTL_VFS_POINTER), stk_t(ptr)))
case FCNTL_VFS_POINTER:
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
if rc == _OK {
const zNameOffset = 16
ptr = util.Read32[ptr_t](c.mod, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset)
ptr = util.ReadUint32(c.mod, ptr)
ptr = util.ReadUint32(c.mod, ptr+zNameOffset)
name := util.ReadString(c.mod, ptr, _MAX_NAME)
if op == FCNTL_VFS_POINTER {
ret = vfs.Find(name)
} else {
ret = name
}
res = vfs.Find(name)
}
case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
rc = c.call("sqlite3_file_control",
uint64(c.handle), uint64(schemaPtr),
uint64(op), uint64(ptr))
if rc == _OK {
const fileHandleOffset = 4
ptr = util.Read32[ptr_t](c.mod, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset)
ret = util.GetHandle(c.ctx, ptr)
ptr = util.ReadUint32(c.mod, ptr)
ptr = util.ReadUint32(c.mod, ptr+fileHandleOffset)
res = util.GetHandle(c.ctx, ptr)
}
}
if err := c.error(rc); err != nil {
return nil, err
}
return ret, nil
return res, nil
}
// Limit allows the size of various constructs to be
@@ -196,20 +182,20 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro
//
// https://sqlite.org/c3ref/limit.html
func (c *Conn) Limit(id LimitCategory, value int) int {
v := int32(c.call("sqlite3_limit", stk_t(c.handle), stk_t(id), stk_t(value)))
return int(v)
r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value))
return int(int32(r))
}
// SetAuthorizer registers an authorizer callback with the database connection.
//
// https://sqlite.org/c3ref/set_authorizer.html
func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error {
var enable int32
var enable uint64
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_set_authorizer_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
return err
}
c.authorizer = cb
@@ -217,7 +203,7 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4
}
func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner ptr_t) (rc AuthorizerReturnCode) {
func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner uint32) (rc AuthorizerReturnCode) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil {
var name3rd, name4th, schema, inner string
if zName3rd != 0 {
@@ -241,15 +227,15 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action A
//
// https://sqlite.org/c3ref/trace_v2.html
func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error {
rc := res_t(c.call("sqlite3_trace_go", stk_t(c.handle), stk_t(mask)))
if err := c.error(rc); err != nil {
r := c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask))
if err := c.error(r); err != nil {
return err
}
c.trace = cb
return nil
}
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc res_t) {
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 uint32) (rc uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil {
var arg1, arg2 any
if evt == TRACE_CLOSE {
@@ -262,14 +248,14 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
case TRACE_STMT:
arg2 = s.SQL()
case TRACE_PROFILE:
arg2 = util.Read64[int64](mod, pArg2)
arg2 = int64(util.ReadUint64(mod, pArg2))
}
break
}
}
}
if arg1 != nil {
_ = c.trace(evt, arg1, arg2)
_, rc = errorCode(c.trace(evt, arg1, arg2), ERROR)
}
}
return rc
@@ -279,28 +265,24 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
//
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
if c.interrupt.Err() != nil {
return 0, 0, INTERRUPT
}
defer c.arena.mark()()
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
schemaPtr := c.arena.string(schema)
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
stk_t(nLogPtr), stk_t(nCkptPtr)))
nLog = int(util.Read32[int32](c.mod, nLogPtr))
nCkpt = int(util.Read32[int32](c.mod, nCkptPtr))
return nLog, nCkpt, c.error(rc)
r := c.call("sqlite3_wal_checkpoint_v2",
uint64(c.handle), uint64(schemaPtr), uint64(mode),
uint64(nLogPtr), uint64(nCkptPtr))
nLog = int(int32(util.ReadUint32(c.mod, nLogPtr)))
nCkpt = int(int32(util.ReadUint32(c.mod, nCkptPtr)))
return nLog, nCkpt, c.error(r)
}
// WALAutoCheckpoint configures WAL auto-checkpoints.
//
// https://sqlite.org/c3ref/wal_autocheckpoint.html
func (c *Conn) WALAutoCheckpoint(pages int) error {
rc := res_t(c.call("sqlite3_wal_autocheckpoint", stk_t(c.handle), stk_t(pages)))
return c.error(rc)
r := c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages))
return c.error(r)
}
// WALHook registers a callback function to be invoked
@@ -308,15 +290,15 @@ func (c *Conn) WALAutoCheckpoint(pages int) error {
//
// https://sqlite.org/c3ref/wal_hook.html
func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) {
var enable int32
var enable uint64
if cb != nil {
enable = 1
}
c.call("sqlite3_wal_hook_go", stk_t(c.handle), stk_t(enable))
c.call("sqlite3_wal_hook_go", uint64(c.handle), enable)
c.wal = cb
}
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc res_t) {
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pages int32) (rc uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil {
schema := util.ReadString(mod, zSchema, _MAX_NAME)
err := c.wal(c, schema, int(pages))
@@ -329,15 +311,15 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pag
//
// https://sqlite.org/c3ref/autovacuum_pages.html
func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error {
var funcPtr ptr_t
var funcPtr uint32
if cb != nil {
funcPtr = util.AddHandle(c.ctx, cb)
}
rc := res_t(c.call("sqlite3_autovacuum_pages_go", stk_t(c.handle), stk_t(funcPtr)))
return c.error(rc)
r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))
return c.error(r)
}
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(schema string, dbPages, freePages, bytesPerPage uint) uint)
schema := util.ReadString(mod, zSchema, _MAX_NAME)
return uint32(fn(schema, uint(nDbPage), uint(nFreePage), uint(nBytePerPage)))
@@ -347,14 +329,14 @@ func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t
//
// https://sqlite.org/c3ref/hard_heap_limit64.html
func (c *Conn) SoftHeapLimit(n int64) int64 {
return int64(c.call("sqlite3_soft_heap_limit64", stk_t(n)))
return int64(c.call("sqlite3_soft_heap_limit64", uint64(n)))
}
// HardHeapLimit imposes a hard limit on heap size.
//
// https://sqlite.org/c3ref/hard_heap_limit64.html
func (c *Conn) HardHeapLimit(n int64) int64 {
return int64(c.call("sqlite3_hard_heap_limit64", stk_t(n)))
return int64(c.call("sqlite3_hard_heap_limit64", uint64(n)))
}
// EnableChecksums enables checksums on a database.
@@ -396,6 +378,6 @@ func (c *Conn) EnableChecksums(schema string) error {
}
// Checkpoint the WAL.
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_FULL)
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_RESTART)
return err
}

256
conn.go
View File

@@ -3,7 +3,6 @@ package sqlite3
import (
"context"
"fmt"
"iter"
"math"
"math/rand"
"net/url"
@@ -25,6 +24,7 @@ type Conn struct {
*sqlite
interrupt context.Context
pending *Stmt
stmts []*Stmt
busy func(context.Context, int) bool
log func(xErrorCode, string)
@@ -35,12 +35,11 @@ type Conn struct {
update func(AuthorizerActionCode, string, string, int64)
commit func() bool
rollback func()
arena arena
busy1st time.Time
busylst time.Time
arena arena
handle ptr_t
gosched uint8
handle uint32
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
@@ -49,7 +48,7 @@ func Open(filename string) (*Conn, error) {
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connection.
// which is used to interrupt the process of opening the connectiton.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
@@ -69,9 +68,9 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(context.Background(), filename, flags)
}
type connKey = util.ConnKey
type connKey struct{}
func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ error) {
func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) {
err := ctx.Err()
if err != nil {
return nil, err
@@ -83,7 +82,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _
return nil, err
}
defer func() {
if ret == nil {
if res == nil {
c.Close()
c.sqlite.close()
} else {
@@ -92,10 +91,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _
}()
c.ctx = context.WithValue(c.ctx, connKey{}, c)
if logger := defaultLogger.Load(); logger != nil {
c.ConfigLog(*logger)
}
c.arena = c.newArena()
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
@@ -106,21 +102,21 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
defer c.arena.mark()()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
rc := res_t(c.call("sqlite3_open_v2", stk_t(namePtr), stk_t(connPtr), stk_t(flags), 0))
r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.Read32[ptr_t](c.mod, connPtr)
if err := c.sqlite.error(rc, handle); err != nil {
handle := util.ReadUint32(c.mod, connPtr)
if err := c.sqlite.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
c.call("sqlite3_progress_handler_go", stk_t(handle), 1000)
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
@@ -132,9 +128,10 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
}
}
if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String())
rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0))
if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil {
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
c.closeDB(handle)
return 0, err
@@ -144,9 +141,9 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
return handle, nil
}
func (c *Conn) closeDB(handle ptr_t) {
rc := res_t(c.call("sqlite3_close_v2", stk_t(handle)))
if err := c.sqlite.error(rc, handle); err != nil {
func (c *Conn) closeDB(handle uint32) {
r := c.call("sqlite3_close_v2", uint64(handle))
if err := c.sqlite.error(r, handle); err != nil {
panic(err)
}
}
@@ -165,8 +162,11 @@ func (c *Conn) Close() error {
return nil
}
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
c.pending.Close()
c.pending = nil
r := c.call("sqlite3_close", uint64(c.handle))
if err := c.error(r); err != nil {
return err
}
@@ -179,17 +179,12 @@ func (c *Conn) Close() error {
//
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
return c.exec(sql)
}
func (c *Conn) exec(sql string) error {
defer c.arena.mark()()
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
return c.error(rc, sql)
sqlPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r, sql)
}
// Prepare calls [Conn.PrepareFlags] with no flags.
@@ -207,26 +202,24 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_SQL_LENGTH {
return nil, "", TOOBIG
}
if c.interrupt.Err() != nil {
return nil, "", INTERRUPT
}
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
textPtr := c.arena.string(sql)
sqlPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
c.checkInterrupt(c.handle)
r := c.call("sqlite3_prepare_v3", uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
stmt = &Stmt{c: c, sql: sql}
stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr)
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" {
stmt = &Stmt{c: c}
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" {
tail = sql
}
if err := c.error(rc, sql); err != nil {
if err := c.error(r, sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
@@ -240,7 +233,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
//
// https://sqlite.org/c3ref/db_name.html
func (c *Conn) DBName(n int) string {
ptr := ptr_t(c.call("sqlite3_db_name", stk_t(c.handle), stk_t(n)))
r := c.call("sqlite3_db_name", uint64(c.handle), uint64(n))
ptr := uint32(r)
if ptr == 0 {
return ""
}
@@ -251,34 +246,34 @@ func (c *Conn) DBName(n int) string {
//
// https://sqlite.org/c3ref/db_filename.html
func (c *Conn) Filename(schema string) *vfs.Filename {
var ptr ptr_t
var ptr uint32
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
ptr = ptr_t(c.call("sqlite3_db_filename", stk_t(c.handle), stk_t(ptr)))
return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB)
r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))
return vfs.GetFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB)
}
// ReadOnly determines if a database is read-only.
//
// https://sqlite.org/c3ref/db_readonly.html
func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) {
var ptr ptr_t
var ptr uint32
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
b := int32(c.call("sqlite3_db_readonly", stk_t(c.handle), stk_t(ptr)))
return b > 0, b < 0
r := c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr))
return int32(r) > 0, int32(r) < 0
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
b := int32(c.call("sqlite3_get_autocommit", stk_t(c.handle)))
return b != 0
r := c.call("sqlite3_get_autocommit", uint64(c.handle))
return r != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
@@ -286,7 +281,8 @@ func (c *Conn) GetAutocommit() bool {
//
// https://sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
return int64(c.call("sqlite3_last_insert_rowid", stk_t(c.handle)))
r := c.call("sqlite3_last_insert_rowid", uint64(c.handle))
return int64(r)
}
// SetLastInsertRowID allows the application to set the value returned by
@@ -294,7 +290,7 @@ func (c *Conn) LastInsertRowID() int64 {
//
// https://sqlite.org/c3ref/set_last_insert_rowid.html
func (c *Conn) SetLastInsertRowID(id int64) {
c.call("sqlite3_set_last_insert_rowid", stk_t(c.handle), stk_t(id))
c.call("sqlite3_set_last_insert_rowid", uint64(c.handle), uint64(id))
}
// Changes returns the number of rows modified, inserted or deleted
@@ -303,7 +299,8 @@ func (c *Conn) SetLastInsertRowID(id int64) {
//
// https://sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
return int64(c.call("sqlite3_changes64", stk_t(c.handle)))
r := c.call("sqlite3_changes64", uint64(c.handle))
return int64(r)
}
// TotalChanges returns the number of rows modified, inserted or deleted
@@ -312,15 +309,16 @@ func (c *Conn) Changes() int64 {
//
// https://sqlite.org/c3ref/total_changes.html
func (c *Conn) TotalChanges() int64 {
return int64(c.call("sqlite3_total_changes64", stk_t(c.handle)))
r := c.call("sqlite3_total_changes64", uint64(c.handle))
return int64(r)
}
// ReleaseMemory frees memory used by a database connection.
//
// https://sqlite.org/c3ref/db_release_memory.html
func (c *Conn) ReleaseMemory() error {
rc := res_t(c.call("sqlite3_db_release_memory", stk_t(c.handle)))
return c.error(rc)
r := c.call("sqlite3_db_release_memory", uint64(c.handle))
return c.error(r)
}
// GetInterrupt gets the context set with [Conn.SetInterrupt].
@@ -343,17 +341,43 @@ func (c *Conn) GetInterrupt() context.Context {
//
// https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
if ctx == nil {
panic("nil Context")
}
old = c.interrupt
c.interrupt = ctx
if ctx == old || ctx.Done() == old.Done() {
return old
}
// A busy SQL statement prevents SQLite from ignoring an interrupt
// that comes before any other statements are started.
if c.pending == nil {
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64,
uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0)
c.pending = &Stmt{c: c}
c.pending.handle = util.ReadUint32(c.mod, stmtPtr)
}
if old.Done() != nil && ctx.Err() == nil {
c.pending.Reset()
}
if ctx.Done() != nil {
c.pending.Step()
}
return old
}
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
func (c *Conn) checkInterrupt(handle uint32) {
if c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", uint64(handle))
}
}
func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.gosched++; c.gosched%16 == 0 {
if c.interrupt.Done() != nil {
runtime.Gosched()
}
if c.interrupt.Err() != nil {
@@ -368,11 +392,11 @@ func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt i
// https://sqlite.org/c3ref/busy_timeout.html
func (c *Conn) BusyTimeout(timeout time.Duration) error {
ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32)
rc := res_t(c.call("sqlite3_busy_timeout", stk_t(c.handle), stk_t(ms)))
return c.error(rc)
r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))
return c.error(r)
}
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry int32) {
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) {
// https://fractaledmind.github.io/2024/04/15/sqlite-on-rails-the-how-and-why-of-optimal-performance/
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil {
switch {
@@ -395,22 +419,25 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r
//
// https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error {
var enable int32
var enable uint64
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_busy_handler_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
return err
}
c.busy = cb
return nil
}
func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) {
func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
if interrupt := c.interrupt; interrupt.Err() == nil &&
c.busy(interrupt, int(count)) {
interrupt := c.interrupt
if interrupt == nil {
interrupt = context.Background()
}
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
retry = 1
}
}
@@ -420,21 +447,21 @@ func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (
// Status retrieves runtime status information about a database connection.
//
// https://sqlite.org/c3ref/db_status.html
func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int64, err error) {
func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err error) {
defer c.arena.mark()()
hiPtr := c.arena.new(8)
curPtr := c.arena.new(8)
hiPtr := c.arena.new(intlen)
curPtr := c.arena.new(intlen)
var i int32
var i uint64
if reset {
i = 1
}
rc := res_t(c.call("sqlite3_db_status64", stk_t(c.handle),
stk_t(op), stk_t(curPtr), stk_t(hiPtr), stk_t(i)))
if err = c.error(rc); err == nil {
current = util.Read64[int64](c.mod, curPtr)
highwater = util.Read64[int64](c.mod, hiPtr)
r := c.call("sqlite3_db_status", uint64(c.handle),
uint64(op), uint64(curPtr), uint64(hiPtr), i)
if err = c.error(r); err == nil {
current = int(util.ReadUint32(c.mod, curPtr))
highwater = int(util.ReadUint32(c.mod, hiPtr))
}
return
}
@@ -444,60 +471,47 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int64, err er
// https://sqlite.org/c3ref/table_column_metadata.html
func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) {
defer c.arena.mark()()
var (
declTypePtr ptr_t
collSeqPtr ptr_t
notNullPtr ptr_t
primaryKeyPtr ptr_t
autoIncPtr ptr_t
columnPtr ptr_t
schemaPtr ptr_t
)
if column != "" {
declTypePtr = c.arena.new(ptrlen)
collSeqPtr = c.arena.new(ptrlen)
notNullPtr = c.arena.new(ptrlen)
primaryKeyPtr = c.arena.new(ptrlen)
autoIncPtr = c.arena.new(ptrlen)
columnPtr = c.arena.string(column)
}
var schemaPtr, columnPtr uint32
declTypePtr := c.arena.new(ptrlen)
collSeqPtr := c.arena.new(ptrlen)
notNullPtr := c.arena.new(ptrlen)
autoIncPtr := c.arena.new(ptrlen)
primaryKeyPtr := c.arena.new(ptrlen)
if schema != "" {
schemaPtr = c.arena.string(schema)
}
tablePtr := c.arena.string(table)
if column != "" {
columnPtr = c.arena.string(column)
}
rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle),
stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(declTypePtr), stk_t(collSeqPtr),
stk_t(notNullPtr), stk_t(primaryKeyPtr), stk_t(autoIncPtr)))
if err = c.error(rc); err == nil && column != "" {
if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 {
r := c.call("sqlite3_table_column_metadata", uint64(c.handle),
uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr),
uint64(declTypePtr), uint64(collSeqPtr),
uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr))
if err = c.error(r); err == nil && column != "" {
if ptr := util.ReadUint32(c.mod, declTypePtr); ptr != 0 {
declType = util.ReadString(c.mod, ptr, _MAX_NAME)
}
if ptr := util.Read32[ptr_t](c.mod, collSeqPtr); ptr != 0 {
if ptr := util.ReadUint32(c.mod, collSeqPtr); ptr != 0 {
collSeq = util.ReadString(c.mod, ptr, _MAX_NAME)
}
notNull = util.ReadBool(c.mod, notNullPtr)
autoInc = util.ReadBool(c.mod, autoIncPtr)
primaryKey = util.ReadBool(c.mod, primaryKeyPtr)
notNull = util.ReadUint32(c.mod, notNullPtr) != 0
autoInc = util.ReadUint32(c.mod, autoIncPtr) != 0
primaryKey = util.ReadUint32(c.mod, primaryKeyPtr) != 0
}
return
}
func (c *Conn) error(rc res_t, sql ...string) error {
func (c *Conn) error(rc uint64, sql ...string) error {
return c.sqlite.error(rc, c.handle, sql...)
}
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() iter.Seq[*Stmt] {
return func(yield func(*Stmt) bool) {
for _, s := range c.stmts {
if !yield(s) {
break
}
func (c *Conn) stmtsIter(yield func(*Stmt) bool) {
for _, s := range c.stmts {
if !yield(s) {
break
}
}
}

11
conn_iter.go Normal file
View File

@@ -0,0 +1,11 @@
//go:build go1.23
package sqlite3
import "iter"
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() iter.Seq[*Stmt] { return c.stmtsIter }

9
conn_old.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !go1.23
package sqlite3
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() func(func(*Stmt) bool) { return c.stmtsIter }

View File

@@ -1,28 +1,19 @@
package sqlite3
import (
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
)
import "strconv"
const (
_OK = 0 /* Successful result */
_ROW = 100 /* sqlite3_step() has another row ready */
_DONE = 101 /* sqlite3_step() has finished executing */
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
_MAX_FUNCTION_ARG = 100
ptrlen = util.PtrLen
intlen = util.IntLen
)
type (
stk_t = util.Stk_t
ptr_t = util.Ptr_t
res_t = util.Res_t
ptrlen = 4
intlen = 4
)
// ErrorCode is a result code that [Error.Code] might return.
@@ -73,9 +64,6 @@ const (
ERROR_MISSING_COLLSEQ ExtendedErrorCode = xErrorCode(ERROR) | (1 << 8)
ERROR_RETRY ExtendedErrorCode = xErrorCode(ERROR) | (2 << 8)
ERROR_SNAPSHOT ExtendedErrorCode = xErrorCode(ERROR) | (3 << 8)
ERROR_RESERVESIZE ExtendedErrorCode = xErrorCode(ERROR) | (4 << 8)
ERROR_KEY ExtendedErrorCode = xErrorCode(ERROR) | (5 << 8)
ERROR_UNABLE ExtendedErrorCode = xErrorCode(ERROR) | (6 << 8)
IOERR_READ ExtendedErrorCode = xErrorCode(IOERR) | (1 << 8)
IOERR_SHORT_READ ExtendedErrorCode = xErrorCode(IOERR) | (2 << 8)
IOERR_WRITE ExtendedErrorCode = xErrorCode(IOERR) | (3 << 8)
@@ -110,8 +98,6 @@ const (
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
IOERR_BADKEY ExtendedErrorCode = xErrorCode(IOERR) | (35 << 8)
IOERR_CODEC ExtendedErrorCode = xErrorCode(IOERR) | (36 << 8)
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)
@@ -173,15 +159,13 @@ const (
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://sqlite.org/c3ref/c_prepare_dont_log.html
// https://sqlite.org/c3ref/c_prepare_normalize.html
type PrepareFlag uint32
const (
PREPARE_PERSISTENT PrepareFlag = 0x01
PREPARE_NORMALIZE PrepareFlag = 0x02
PREPARE_NO_VTAB PrepareFlag = 0x04
PREPARE_DONT_LOG PrepareFlag = 0x10
PREPARE_FROM_DDL PrepareFlag = 0x20
)
// FunctionFlag is a flag that can be passed to
@@ -191,12 +175,12 @@ const (
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
RESULT_SUBTYPE FunctionFlag = 0x001000000
SELFORDER1 FunctionFlag = 0x002000000
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
INNOCUOUS FunctionFlag = 0x000200000
SELFORDER1 FunctionFlag = 0x002000000
// SUBTYPE FunctionFlag = 0x000100000
// RESULT_SUBTYPE FunctionFlag = 0x001000000
)
// StmtStatus name counter values associated with the [Stmt.Status] method.
@@ -235,8 +219,6 @@ const (
DBSTATUS_DEFERRED_FKS DBStatus = 10
DBSTATUS_CACHE_USED_SHARED DBStatus = 11
DBSTATUS_CACHE_SPILL DBStatus = 12
DBSTATUS_TEMPBUF_SPILL DBStatus = 13
// DBSTATUS_MAX DBStatus = 13
)
// DBConfig are the available database connection configuration options.
@@ -265,10 +247,7 @@ const (
DBCONFIG_TRUSTED_SCHEMA DBConfig = 1017
DBCONFIG_STMT_SCANSTATUS DBConfig = 1018
DBCONFIG_REVERSE_SCANORDER DBConfig = 1019
DBCONFIG_ENABLE_ATTACH_CREATE DBConfig = 1020
DBCONFIG_ENABLE_ATTACH_WRITE DBConfig = 1021
DBCONFIG_ENABLE_COMMENTS DBConfig = 1022
// DBCONFIG_MAX DBConfig = 1022
// DBCONFIG_MAX DBConfig = 1019
)
// FcntlOpcode are the available opcodes for [Conn.FileControl].
@@ -281,14 +260,12 @@ const (
FCNTL_CHUNK_SIZE FcntlOpcode = 6
FCNTL_FILE_POINTER FcntlOpcode = 7
FCNTL_PERSIST_WAL FcntlOpcode = 10
FCNTL_VFSNAME FcntlOpcode = 12
FCNTL_POWERSAFE_OVERWRITE FcntlOpcode = 13
FCNTL_VFS_POINTER FcntlOpcode = 27
FCNTL_JOURNAL_POINTER FcntlOpcode = 28
FCNTL_DATA_VERSION FcntlOpcode = 35
FCNTL_RESERVE_BYTES FcntlOpcode = 38
FCNTL_RESET_CACHE FcntlOpcode = 42
FCNTL_NULL_IO FcntlOpcode = 43
)
// LimitCategory are the available run-time limit categories.
@@ -309,7 +286,6 @@ const (
LIMIT_VARIABLE_NUMBER LimitCategory = 9
LIMIT_TRIGGER_DEPTH LimitCategory = 10
LIMIT_WORKER_THREADS LimitCategory = 11
LIMIT_PARSER_DEPTH LimitCategory = 12
)
// AuthorizerActionCode are the integer action codes
@@ -371,14 +347,13 @@ const (
// CheckpointMode are all the checkpoint mode values.
//
// https://sqlite.org/c3ref/c_checkpoint_full.html
type CheckpointMode int32
type CheckpointMode uint32
const (
CHECKPOINT_NOOP CheckpointMode = -1 /* Do no work at all */
CHECKPOINT_PASSIVE CheckpointMode = 0 /* Do as much as possible w/o blocking */
CHECKPOINT_FULL CheckpointMode = 1 /* Wait for writers, then checkpoint */
CHECKPOINT_RESTART CheckpointMode = 2 /* Like FULL but wait for readers */
CHECKPOINT_TRUNCATE CheckpointMode = 3 /* Like RESTART but also truncate WAL */
CHECKPOINT_PASSIVE CheckpointMode = 0 /* Do as much as possible w/o blocking */
CHECKPOINT_FULL CheckpointMode = 1 /* Wait for writers, then checkpoint */
CHECKPOINT_RESTART CheckpointMode = 2 /* Like FULL but wait for readers */
CHECKPOINT_TRUNCATE CheckpointMode = 3 /* Like RESTART but also truncate WAL */
)
// TxnState are the allowed return values from [Conn.TxnState].

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"encoding/json"
"errors"
"math"
"time"
@@ -14,7 +15,7 @@ import (
// https://sqlite.org/c3ref/context.html
type Context struct {
c *Conn
handle ptr_t
handle uint32
}
// Conn returns the database connection of the
@@ -31,14 +32,14 @@ func (ctx Context) Conn() *Conn {
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(ctx.c.ctx, data)
ctx.c.call("sqlite3_set_auxdata_go", stk_t(ctx.handle), stk_t(n), stk_t(ptr))
ctx.c.call("sqlite3_set_auxdata_go", uint64(ctx.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) GetAuxData(n int) any {
ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", stk_t(ctx.handle), stk_t(n)))
ptr := uint32(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n)))
return util.GetHandle(ctx.c.ctx, ptr)
}
@@ -67,7 +68,7 @@ func (ctx Context) ResultInt(value int) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultInt64(value int64) {
ctx.c.call("sqlite3_result_int64",
stk_t(ctx.handle), stk_t(value))
uint64(ctx.handle), uint64(value))
}
// ResultFloat sets the result of the function to a float64.
@@ -75,7 +76,7 @@ func (ctx Context) ResultInt64(value int64) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultFloat(value float64) {
ctx.c.call("sqlite3_result_double",
stk_t(ctx.handle), stk_t(math.Float64bits(value)))
uint64(ctx.handle), math.Float64bits(value))
}
// ResultText sets the result of the function to a string.
@@ -84,33 +85,27 @@ func (ctx Context) ResultFloat(value float64) {
func (ctx Context) ResultText(value string) {
ptr := ctx.c.newString(value)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
}
// ResultRawText sets the text result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultRawText(value []byte) {
if len(value) == 0 {
ctx.ResultText("")
return
}
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultBlob(value []byte) {
if len(value) == 0 {
ctx.ResultZeroBlob(0)
return
}
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_blob_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
uint64(ctx.handle), uint64(ptr), uint64(len(value)))
}
// ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB.
@@ -118,7 +113,7 @@ func (ctx Context) ResultBlob(value []byte) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultZeroBlob(n int64) {
ctx.c.call("sqlite3_result_zeroblob64",
stk_t(ctx.handle), stk_t(n))
uint64(ctx.handle), uint64(n))
}
// ResultNull sets the result of the function to NULL.
@@ -126,7 +121,7 @@ func (ctx Context) ResultZeroBlob(n int64) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultNull() {
ctx.c.call("sqlite3_result_null",
stk_t(ctx.handle))
uint64(ctx.handle))
}
// ResultTime sets the result of the function to a [time.Time].
@@ -151,14 +146,14 @@ func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
}
func (ctx Context) resultRFC3339Nano(value time.Time) {
const maxlen = int64(len(time.RFC3339Nano)) + 5
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := ctx.c.new(maxlen)
buf := util.View(ctx.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(buf)))
uint64(ctx.handle), uint64(ptr), uint64(len(buf)))
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
@@ -169,7 +164,19 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call("sqlite3_result_pointer_go",
stk_t(ctx.handle), stk_t(valPtr))
uint64(ctx.handle), uint64(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultRawText(data)
}
// ResultValue sets the result of the function to a copy of [Value].
@@ -181,7 +188,7 @@ func (ctx Context) ResultValue(value Value) {
return
}
ctx.c.call("sqlite3_result_value",
stk_t(ctx.handle), stk_t(value.handle))
uint64(ctx.handle), uint64(value.handle))
}
// ResultError sets the result of the function an error.
@@ -189,41 +196,33 @@ func (ctx Context) ResultValue(value Value) {
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
ctx.c.call("sqlite3_result_error_nomem", stk_t(ctx.handle))
ctx.c.call("sqlite3_result_error_nomem", uint64(ctx.handle))
return
}
if errors.Is(err, TOOBIG) {
ctx.c.call("sqlite3_result_error_toobig", stk_t(ctx.handle))
ctx.c.call("sqlite3_result_error_toobig", uint64(ctx.handle))
return
}
msg, code := errorCode(err, ERROR)
msg, code := errorCode(err, _OK)
if msg != "" {
defer ctx.c.arena.mark()()
ptr := ctx.c.arena.string(msg)
ctx.c.call("sqlite3_result_error",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(msg)))
uint64(ctx.handle), uint64(ptr), uint64(len(msg)))
}
if code != res_t(ERROR) {
if code != _OK {
ctx.c.call("sqlite3_result_error_code",
stk_t(ctx.handle), stk_t(code))
uint64(ctx.handle), uint64(code))
}
}
// ResultSubtype sets the subtype of the result of the function.
//
// https://sqlite.org/c3ref/result_subtype.html
func (ctx Context) ResultSubtype(t uint) {
ctx.c.call("sqlite3_result_subtype",
stk_t(ctx.handle), stk_t(uint32(t)))
}
// VTabNoChange may return true if a column is being fetched as part
// of an update during which the column value will not change.
//
// https://sqlite.org/c3ref/vtab_nochange.html
func (ctx Context) VTabNoChange() bool {
b := int32(ctx.c.call("sqlite3_vtab_nochange", stk_t(ctx.handle)))
return b != 0
r := ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle))
return r != 0
}

View File

@@ -20,45 +20,22 @@
// - a [serializable] transaction is always "immediate";
// - a [read-only] transaction is always "deferred".
//
// # Datatypes In SQLite
//
// SQLite is dynamically typed.
// Columns can mostly hold any value regardless of their declared type.
// SQLite supports most [driver.Value] types out of the box,
// but bool and [time.Time] require special care.
//
// Booleans can be stored on any column type and scanned back to a *bool.
// However, if scanned to a *any, booleans may either become an
// int64, string or bool, depending on the declared type of the column.
// If you use BOOLEAN for your column type,
// 1 and 0 will always scan as true and false.
//
// # Working with time
//
// Time values can similarly be stored on any column type.
// The time encoding/decoding format can be specified using "_timefmt":
//
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
//
// Special values are: "auto" (the default), "sqlite", "rfc3339";
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// - "rfc3339" encodes and decodes RFC 3339 only.
//
// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
//
// If you encode as RFC 3339 (the default),
// consider using the TIME [collating sequence] to produce time-ordered sequences.
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
//
// If you encode as RFC 3339 (the default),
// time values will scan back to a *time.Time unless your column type is TEXT.
// Otherwise, if scanned to a *any, time values may either become an
// int64, float64 or string, depending on the time format and declared type of the column.
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
// "_timefmt" will be used to decode values.
//
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
//
// When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner].
@@ -71,7 +48,7 @@
// The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, or nil,
// depending on the column type and whoever wrote the value.
// depending on the column type and what whoever wrote the value.
// [sqlite3.TimeFormat.Decode] may help.
//
// # Setting PRAGMAs
@@ -224,7 +201,7 @@ func (n *connector) Driver() driver.Driver {
return &SQLite{}
}
func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
c := &conn{
txLock: n.txLock,
tmRead: n.tmRead,
@@ -236,14 +213,13 @@ func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
return nil, err
}
defer func() {
if ret == nil {
if res == nil {
c.Close()
}
}()
if old := c.Conn.SetInterrupt(ctx); old != ctx {
defer c.Conn.SetInterrupt(old)
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
if !n.pragmas {
err = c.Conn.BusyTimeout(time.Minute)
@@ -263,8 +239,10 @@ func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
return nil, err
}
defer s.Close()
if s.Step() {
c.readOnly = s.ColumnBool(0)
if s.Step() && s.ColumnBool(0) {
c.readOnly = '1'
} else {
c.readOnly = '0'
}
err = s.Close()
if err != nil {
@@ -320,7 +298,7 @@ type conn struct {
txReset string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
readOnly bool
readOnly byte
}
var (
@@ -356,14 +334,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
c.txReset = ``
txBegin := `BEGIN ` + txLock
if opts.ReadOnly && !c.readOnly {
if opts.ReadOnly {
txBegin += ` ; PRAGMA query_only=on`
c.txReset = `; PRAGMA query_only=off`
c.txReset = `; PRAGMA query_only=` + string(c.readOnly)
}
if old := c.Conn.SetInterrupt(ctx); old != ctx {
defer c.Conn.SetInterrupt(old)
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.Conn.Exec(txBegin)
if err != nil {
@@ -381,12 +358,13 @@ func (c *conn) Commit() error {
}
func (c *conn) Rollback() error {
// ROLLBACK even if interrupted.
ctx := context.Background()
if old := c.Conn.SetInterrupt(ctx); old != ctx {
err := c.Conn.Exec(`ROLLBACK` + c.txReset)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(`ROLLBACK` + c.txReset)
}
return c.Conn.Exec(`ROLLBACK` + c.txReset)
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
@@ -395,9 +373,8 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) {
}
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if old := c.Conn.SetInterrupt(ctx); old != ctx {
defer c.Conn.SetInterrupt(old)
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
s, tail, err := c.Conn.Prepare(query)
if err != nil {
@@ -422,9 +399,8 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return resultRowsAffected(0), nil
}
if old := c.Conn.SetInterrupt(ctx); old != ctx {
defer c.Conn.SetInterrupt(old)
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.Conn.Exec(query)
if err != nil {
@@ -487,19 +463,16 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
return nil, err
}
c := s.Stmt.Conn()
if old := c.SetInterrupt(ctx); old != ctx {
defer c.SetInterrupt(old)
}
old := s.Stmt.Conn().SetInterrupt(ctx)
defer s.Stmt.Conn().SetInterrupt(old)
err = errors.Join(
s.Stmt.Exec(),
s.Stmt.ClearBindings())
err = s.Stmt.Exec()
s.Stmt.ClearBindings()
if err != nil {
return nil, err
}
return newResult(c), nil
return newResult(s.Stmt.Conn()), nil
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -544,8 +517,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
err = s.Stmt.BindTime(id, a, s.tmWrite)
case util.JSON:
err = s.Stmt.BindJSON(id, a.Value)
case util.Pointer:
err = s.Stmt.BindPointer(id, a.Value)
case util.PointerUnwrap:
err = s.Stmt.BindPointer(id, util.UnwrapPointer(a))
case nil:
err = s.Stmt.BindNull(id)
default:
@@ -563,7 +536,7 @@ func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
time.Time, sqlite3.ZeroBlob,
util.JSON, util.Pointer,
util.JSON, util.PointerUnwrap,
nil:
return nil
default:
@@ -602,59 +575,28 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
return int64(r), nil
}
type scantype byte
const (
_ANY scantype = iota
_INT
_REAL
_TEXT
_BLOB
_NULL
_BOOL
_TIME
_NOT_NULL
)
var (
_ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{}
_ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{}
_ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{}
_ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{}
_ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{}
_ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{}
)
func scanFromDecl(decl string) scantype {
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch decl {
case "INT", "INTEGER":
return _INT
case "REAL":
return _REAL
case "TEXT":
return _TEXT
case "BLOB":
return _BLOB
case "BOOLEAN":
return _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
return _TIME
}
return _ANY
}
type rows struct {
ctx context.Context
*stmt
names []string
types []string
nulls []bool
scans []scantype
}
type scantype byte
const (
_ANY scantype = iota
_INT scantype = scantype(sqlite3.INTEGER)
_REAL scantype = scantype(sqlite3.FLOAT)
_TEXT scantype = scantype(sqlite3.TEXT)
_BLOB scantype = scantype(sqlite3.BLOB)
_NULL scantype = scantype(sqlite3.NULL)
_BOOL scantype = iota
_TIME
)
var (
// Ensure these interfaces are implemented:
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@@ -662,9 +604,8 @@ var (
)
func (r *rows) Close() error {
return errors.Join(
r.Stmt.Reset(),
r.Stmt.ClearBindings())
r.Stmt.ClearBindings()
return r.Stmt.Reset()
}
func (r *rows) Columns() []string {
@@ -679,69 +620,79 @@ func (r *rows) Columns() []string {
return r.names
}
func (r *rows) scanType(index int) scantype {
if r.scans == nil {
count := len(r.names)
scans := make([]scantype, count)
for i := range scans {
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
}
r.scans = scans
}
return r.scans[index] &^ _NOT_NULL
}
func (r *rows) loadColumnMetadata() {
if r.types == nil {
c := r.Stmt.Conn()
count := len(r.names)
if r.nulls == nil {
count := r.Stmt.ColumnCount()
nulls := make([]bool, count)
types := make([]string, count)
scans := make([]scantype, count)
for i := range types {
var declType string
var notNull, autoInc bool
if column := r.Stmt.ColumnOriginName(i); column != "" {
declType, _, notNull, _, autoInc, _ = c.TableColumnMetadata(
for i := range nulls {
if col := r.Stmt.ColumnOriginName(i); col != "" {
types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i),
column)
} else {
declType = r.Stmt.ColumnDeclType(i)
}
if declType != "" {
declType = strings.ToUpper(declType)
scans[i] = scanFromDecl(declType)
types[i] = declType
}
if notNull || autoInc {
scans[i] |= _NOT_NULL
col)
types[i] = strings.ToUpper(types[i])
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch types[i] {
case "INT", "INTEGER":
scans[i] = _INT
case "REAL":
scans[i] = _REAL
case "TEXT":
scans[i] = _TEXT
case "BLOB":
scans[i] = _BLOB
case "BOOLEAN":
scans[i] = _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
scans[i] = _TIME
}
}
}
r.nulls = nulls
r.types = types
r.scans = scans
}
}
func (r *rows) declType(index int) string {
if r.types == nil {
count := r.Stmt.ColumnCount()
types := make([]string, count)
for i := range types {
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
}
r.types = types
}
return r.types[index]
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
r.loadColumnMetadata()
decl := r.types[index]
if len := len(decl); len > 0 && decl[len-1] == ')' {
if i := strings.LastIndexByte(decl, '('); i >= 0 {
decl = decl[:i]
decltype := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
decltype = decltype[:i]
}
}
return strings.TrimSpace(decl)
return strings.TrimSpace(decltype)
}
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
r.loadColumnMetadata()
nullable = r.scans[index]&^_NOT_NULL == 0
return nullable, !nullable
if r.nulls[index] {
return false, true
}
return true, false
}
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
r.loadColumnMetadata()
scan := r.scans[index] &^ _NOT_NULL
scan := r.scans[index]
if r.Stmt.Busy() {
// SQLite is dynamically typed and we now have a row.
@@ -753,7 +704,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
switch {
case scan == _TIME && val != _BLOB && val != _NULL:
t := r.Stmt.ColumnTime(index, r.tmRead)
useValType = t.IsZero()
useValType = t == time.Time{}
case scan == _BOOL && val == _INT:
i := r.Stmt.ColumnInt64(index)
useValType = i != 0 && i != 1
@@ -767,27 +718,25 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
switch scan {
case _INT:
return reflect.TypeFor[int64]()
return reflect.TypeOf(int64(0))
case _REAL:
return reflect.TypeFor[float64]()
return reflect.TypeOf(float64(0))
case _TEXT:
return reflect.TypeFor[string]()
return reflect.TypeOf("")
case _BLOB:
return reflect.TypeFor[[]byte]()
return reflect.TypeOf([]byte{})
case _BOOL:
return reflect.TypeFor[bool]()
return reflect.TypeOf(false)
case _TIME:
return reflect.TypeFor[time.Time]()
return reflect.TypeOf(time.Time{})
default:
return reflect.TypeFor[any]()
return reflect.TypeOf((*any)(nil)).Elem()
}
}
func (r *rows) Next(dest []driver.Value) error {
c := r.Stmt.Conn()
if old := c.SetInterrupt(r.ctx); old != r.ctx {
defer c.SetInterrupt(old)
}
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)
if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {
@@ -797,41 +746,36 @@ func (r *rows) Next(dest []driver.Value) error {
}
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
if err := r.Stmt.ColumnsRaw(data...); err != nil {
return err
}
err := r.Stmt.Columns(data...)
for i := range dest {
scan := r.scanType(i)
if v, ok := dest[i].([]byte); ok {
if len(v) == cap(v) { // a BLOB
continue
}
if scan != _TEXT {
switch r.tmWrite {
case "", time.RFC3339, time.RFC3339Nano:
t, ok := maybeTime(v)
if ok {
dest[i] = t
continue
}
}
}
dest[i] = string(v)
}
switch scan {
case _TIME:
t, err := r.tmRead.Decode(dest[i])
if err == nil {
dest[i] = t
}
case _BOOL:
switch dest[i] {
case int64(0):
dest[i] = false
case int64(1):
dest[i] = true
}
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
}
}
return nil
return err
}
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
switch v := v.(type) {
case int64, float64:
// could be a time value
case string:
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
break
}
t, ok := maybeTime(v)
if ok {
return t, true
}
default:
return
}
switch r.declType(i) {
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// could be a time value
default:
return
}
t, err := r.tmRead.Decode(v)
return t, err == nil
}

View File

@@ -33,7 +33,7 @@ func Test_Open_error(t *testing.T) {
func Test_Open_dir(t *testing.T) {
t.Parallel()
db, err := Open(".")
db, err := sql.Open("sqlite3", ".")
if err != nil {
t.Fatal(err)
}
@@ -43,18 +43,18 @@ func Test_Open_dir(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN_ISDIR) {
t.Errorf("got %v, want sqlite3.CANTOPEN_ISDIR", err)
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func Test_Open_pragma(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t, url.Values{
tmp := memdb.TestDB(t, url.Values{
"_pragma": {"busy_timeout(1000)"},
})
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -72,11 +72,11 @@ func Test_Open_pragma(t *testing.T) {
func Test_Open_pragma_invalid(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t, url.Values{
tmp := memdb.TestDB(t, url.Values{
"_pragma": {"busy_timeout 1000"},
})
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -100,12 +100,12 @@ func Test_Open_pragma_invalid(t *testing.T) {
func Test_Open_txLock(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t, url.Values{
tmp := memdb.TestDB(t, url.Values{
"_txlock": {"exclusive"},
"_pragma": {"busy_timeout(1000)"},
})
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -136,11 +136,11 @@ func Test_Open_txLock(t *testing.T) {
func Test_Open_txLock_invalid(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t, url.Values{
tmp := memdb.TestDB(t, url.Values{
"_txlock": {"xclusive"},
})
_, err := Open(dsn)
_, err := sql.Open("sqlite3", tmp+"_txlock=xclusive")
if err == nil {
t.Fatal("want error")
}
@@ -151,28 +151,31 @@ func Test_Open_txLock_invalid(t *testing.T) {
func Test_BeginTx(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t, url.Values{
tmp := memdb.TestDB(t, url.Values{
"_txlock": {"exclusive"},
"_pragma": {"busy_timeout(0)"},
})
db, err := Open(dsn)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(t.Context(), &sql.TxOptions{Isolation: sql.LevelReadCommitted})
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(util.IsolationErr) {
t.Error("want isolationErr")
}
tx1, err := db.BeginTx(t.Context(), &sql.TxOptions{ReadOnly: true})
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
tx2, err := db.BeginTx(t.Context(), &sql.TxOptions{ReadOnly: true})
tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
@@ -196,69 +199,11 @@ func Test_BeginTx(t *testing.T) {
}
}
func Test_nested_context(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
db, err := Open(dsn)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
outer, err := tx.Query(`SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer outer.Close()
want := func(rows *sql.Rows, want int) {
t.Helper()
var got int
rows.Next()
if err := rows.Scan(&got); err != nil {
t.Fatal(err)
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
want(outer, 0)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer inner.Close()
want(inner, 0)
cancel()
var terr interface{ Temporary() bool }
if inner.Next() || !errors.Is(inner.Err(), context.Canceled) &&
(!errors.As(inner.Err(), &terr) || !terr.Temporary()) {
t.Fatalf("got %v, want cancellation", inner.Err())
}
want(outer, 1)
}
func Test_Prepare(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -297,21 +242,24 @@ func Test_Prepare(t *testing.T) {
func Test_QueryRow_named(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := Open(dsn)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, err := db.Conn(t.Context())
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
stmt, err := conn.PrepareContext(t.Context(), `SELECT ?, ?5, :AAA, @AAA, $AAA`)
stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`)
if err != nil {
t.Fatal(err)
}
@@ -347,9 +295,9 @@ func Test_QueryRow_named(t *testing.T) {
func Test_QueryRow_blob_null(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -384,11 +332,11 @@ func Test_time(t *testing.T) {
for _, fmt := range []string{"auto", "sqlite", "rfc3339", time.ANSIC} {
t.Run(fmt, func(t *testing.T) {
dsn := memdb.TestDB(t, url.Values{
tmp := memdb.TestDB(t, url.Values{
"_timefmt": {fmt},
})
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -421,19 +369,19 @@ func Test_time(t *testing.T) {
func Test_ColumnType_ScanType(t *testing.T) {
var (
INT = reflect.TypeFor[int64]()
REAL = reflect.TypeFor[float64]()
TEXT = reflect.TypeFor[string]()
BLOB = reflect.TypeFor[[]byte]()
BOOL = reflect.TypeFor[bool]()
TIME = reflect.TypeFor[time.Time]()
ANY = reflect.TypeFor[any]()
INT = reflect.TypeOf(int64(0))
REAL = reflect.TypeOf(float64(0))
TEXT = reflect.TypeOf("")
BLOB = reflect.TypeOf([]byte{})
BOOL = reflect.TypeOf(false)
TIME = reflect.TypeOf(time.Time{})
ANY = reflect.TypeOf((*any)(nil)).Elem()
)
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := Open(dsn)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
@@ -519,25 +467,3 @@ func Test_ColumnType_ScanType(t *testing.T) {
t.Fatal(err)
}
}
func Benchmark_loop(b *testing.B) {
db, err := Open(":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()
var version string
err = db.QueryRow(`SELECT sqlite_version();`).Scan(&version)
if err != nil {
b.Fatal(err)
}
for b.Loop() {
_, err := db.ExecContext(b.Context(),
`WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x < 1000000) SELECT x FROM c;`)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -1,5 +1,9 @@
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
package driver_test
// Adapted from: https://go.dev/doc/tutorial/database-access
import (
"database/sql"
"database/sql/driver"
@@ -23,7 +27,7 @@ func Example_customTime() {
_, err = db.Exec(`
CREATE TABLE data (
id INTEGER PRIMARY KEY,
date_time ANY
date_time TEXT
) STRICT;
`)
if err != nil {

View File

@@ -1,15 +1,12 @@
package driver
import (
"bytes"
"time"
)
import "time"
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text []byte) (_ time.Time, _ bool) {
func maybeTime(text string) (_ time.Time, _ bool) {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
@@ -24,8 +21,8 @@ func maybeTime(text []byte) (_ time.Time, _ bool) {
// Slow path.
var buf [len(time.RFC3339Nano)]byte
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
return date, true
}
return

View File

@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
v, ok := maybeTime([]byte(str))
v, ok := maybeTime(str)
if ok {
// Make sure times round-trip to the same string:
// https://pkg.go.dev/database/sql#Rows.Scan
@@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
checkTime := func(t testing.TB, date time.Time) {
v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano))
v, ok := maybeTime(date.Format(time.RFC3339Nano))
if ok {
// Make sure times round-trip to the same time:
if !v.Equal(date) {

View File

@@ -1,8 +1,9 @@
package driver
import (
"context"
"database/sql/driver"
"slices"
"reflect"
"testing"
_ "github.com/ncruces/go-sqlite3/embed"
@@ -15,7 +16,7 @@ func Test_namedValues(t *testing.T) {
{Ordinal: 2, Value: false},
}
got := namedValues([]driver.Value{true, false})
if !slices.Equal(got, want) {
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}
@@ -55,7 +56,7 @@ func Fuzz_notWhitespace(f *testing.F) {
t.SkipNow()
}
c, err := db.Conn(t.Context())
c, err := db.Conn(context.Background())
if err != nil {
t.Fatal(err)
}

View File

@@ -1,6 +1,6 @@
# Embeddable Wasm build of SQLite
This folder includes an embeddable Wasm build of SQLite 3.51.1 for use with
This folder includes an embeddable Wasm build of SQLite 3.47.2 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:

View File

@@ -1,19 +1,13 @@
# Embeddable Wasm build of SQLite
This folder includes an alternative embeddable Wasm build of SQLite,
which includes the experimental
This folder includes an embeddable Wasm build of SQLite, including the experimental
[`BEGIN CONCURRENT`](https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md) and
[Wal2](https://sqlite.org/cgi/src/doc/wal2/doc/wal2.md) patches.
It also enables the optional
[`UPDATE … ORDER BY … LIMIT`](https://sqlite.org/lang_update.html#optional_limit_and_order_by_clauses) and
[`DELETE … ORDER BY … LIMIT`](https://sqlite.org/lang_delete.html#optional_limit_and_order_by_clauses) clauses,
and the [`WITHIN GROUP ORDER BY`](https://sqlite.org/compile.html#enable_ordered_set_aggregates) aggregate syntax.
> [!IMPORTANT]
> This package is experimental.
> It is built from the `bedrock` branch of SQLite,
> since that is _currently_ the most stable, maintained branch to include these features.
> since that is _currently_ the most stable, maintained branch to include both features.
> [!CAUTION]
> The Wal2 journaling mode creates databases that other versions of SQLite cannot access.

Binary file not shown.

View File

@@ -5,7 +5,6 @@ import (
"testing"
"github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/ext/stats"
"github.com/ncruces/go-sqlite3/vfs"
)
@@ -16,7 +15,7 @@ func Test_bcw2(t *testing.T) {
tmp := filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))
db, err := driver.Open("file:"+tmp+"?_pragma=journal_mode(wal2)&_txlock=concurrent", stats.Register)
db, err := driver.Open("file:" + tmp + "?_pragma=journal_mode(wal2)&_txlock=concurrent")
if err != nil {
t.Fatal(err)
}
@@ -38,11 +37,6 @@ func Test_bcw2(t *testing.T) {
t.Fatal(err)
}
_, err = tx.Exec(`SELECT median() WITHIN GROUP (ORDER BY col) FROM test`)
if err != nil {
t.Fatal(err)
}
err = tx.Commit()
if err != nil {
t.Fatal(err)
@@ -53,7 +47,7 @@ func Test_bcw2(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if version != "3.52.0" {
if version != "3.48.0" {
t.Error(version)
}
}

View File

@@ -7,24 +7,20 @@ ROOT=../../
BINARYEN="$ROOT/tools/binaryen/bin"
WASI_SDK="$ROOT/tools/wasi-sdk/bin"
trap 'rm -rf sqlite/ build/ bcw2.tmp' EXIT
trap 'rm -rf build/ sqlite/ bcw2.tmp' EXIT
mkdir -p sqlite/
mkdir -p build/ext/
cp "$ROOT"/sqlite3/*.[ch] build/
cp "$ROOT"/sqlite3/*.patch build/
cd sqlite/
# https://sqlite.org/src/info/f273f6b8245c5dca
curl -#L https://github.com/sqlite/sqlite/archive/7c126d7.tar.gz | tar xz --strip-components=1
# curl -#L https://sqlite.org/src/tarball/sqlite.tar.gz?r=f273f6b824 | tar xz --strip-components=1
# https://sqlite.org/src/info/ec5d7025cba9f4ac
curl -# https://sqlite.org/src/tarball/sqlite.tar.gz?r=ec5d7025 | tar xz
cd sqlite
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
MSYS_NO_PATHCONV=1 nmake /f makefile.msc sqlite3.c "OPTS=-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT -DSQLITE_ENABLE_ORDERED_SET_AGGREGATES"
MSYS_NO_PATHCONV=1 nmake /f makefile.msc sqlite3.c OPTS=-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT
else
sh configure --enable-update-limit
make verify-source
OPTS=-DSQLITE_ENABLE_ORDERED_SET_AGGREGATES make sqlite3.c
sh configure --enable-update-limit && make sqlite3.c
fi
cd ~-
@@ -41,33 +37,29 @@ mv sqlite/ext/misc/spellfix.c build/ext/
mv sqlite/ext/misc/uint.c build/ext/
cd build
cat *.patch | patch -p0 --no-backup-if-mismatch
cat *.patch | patch --no-backup-if-mismatch
cd ~-
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -g0 -O2 \
-Wall -Wextra -Wno-unused-parameter -Wno-unused-function \
-o bcw2.wasm build/main.c \
-I"$ROOT/sqlite3/libc" -I"build" \
-o bcw2.wasm "build/main.c" \
-I"build" \
-mexec-model=reactor \
-mmutable-globals -mnontrapping-fptoint \
-msimd128 -mbulk-memory -msign-ext \
-mreference-types -mmultivalue \
-mno-extended-const \
-fno-stack-protector \
-msimd128 -mmutable-globals -mmultivalue \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--stack-first \
-Wl,--import-undefined \
-Wl,--initial-memory=327680 \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT \
-DSQLITE_ENABLE_ORDERED_SET_AGGREGATES \
-DSQLITE_EXPERIMENTAL_PRAGMA_20251114 \
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
$(awk '{print "-Wl,--export="$0}' ../exports.txt)
"$BINARYEN/wasm-ctor-eval" -g -c _initialize bcw2.wasm -o bcw2.tmp
"$BINARYEN/wasm-opt" -g bcw2.tmp -o bcw2.wasm \
--gufa --generate-global-effects --low-memory-unused --converge -O3 \
--enable-mutable-globals --enable-nontrapping-float-to-int \
--enable-simd --enable-bulk-memory --enable-sign-ext \
--enable-reference-types --enable-multivalue \
--strip --strip-producers
"$BINARYEN/wasm-opt" -g --strip --strip-producers -c -O3 \
bcw2.tmp -o bcw2.wasm \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

View File

@@ -1,12 +1,13 @@
module github.com/ncruces/go-sqlite3/embed/bcw2
go 1.24.0
go 1.21
require github.com/ncruces/go-sqlite3 v0.30.3
toolchain go1.23.0
require github.com/ncruces/go-sqlite3 v0.21.3
require (
github.com/ncruces/julianday v1.0.0 // indirect
github.com/ncruces/sort v0.1.6 // indirect
github.com/tetratelabs/wazero v1.11.0 // indirect
golang.org/x/sys v0.39.0 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
golang.org/x/sys v0.29.0 // indirect
)

View File

@@ -1,12 +1,10 @@
github.com/ncruces/go-sqlite3 v0.30.3 h1:X/CgWW9GzmIAkEPrifhKqf0cC15DuOVxAJaHFTTAURQ=
github.com/ncruces/go-sqlite3 v0.30.3/go.mod h1:AxKu9sRxkludimFocbktlY6LiYSkxiI5gTA8r+os/Nw=
github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1CdMQ=
github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/sort v0.1.6 h1:TrsJfGRH1AoWoaeB4/+gCohot9+cA6u/INaH5agIhNk=
github.com/ncruces/sort v0.1.6/go.mod h1:obJToO4rYr6VWP0Uw5FYymgYGt3Br4RXcs/JdKaXAPk=
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=

View File

@@ -11,14 +11,13 @@ package bcw2
import (
_ "embed"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
//go:embed bcw2.wasm
var binary string
var binary []byte
func init() {
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
sqlite3.Binary = binary
}

View File

@@ -12,25 +12,22 @@ trap 'rm -f sqlite3.tmp' EXIT
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -g0 -O2 \
-Wall -Wextra -Wno-unused-parameter -Wno-unused-function \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3/libc" -I"$ROOT/sqlite3" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals -mnontrapping-fptoint \
-msimd128 -mbulk-memory -msign-ext \
-mreference-types -mmultivalue \
-mno-extended-const \
-fno-stack-protector \
-msimd128 -mmutable-globals -mmultivalue \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--stack-first \
-Wl,--import-undefined \
-Wl,--initial-memory=327680 \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_EXPERIMENTAL_PRAGMA_20251114 \
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
$(awk '{print "-Wl,--export="$0}' exports.txt)
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g sqlite3.tmp -o sqlite3.wasm \
--gufa --generate-global-effects --low-memory-unused --converge -O3 \
--enable-mutable-globals --enable-nontrapping-float-to-int \
--enable-simd --enable-bulk-memory --enable-sign-ext \
--enable-reference-types --enable-multivalue \
--strip --strip-producers
"$BINARYEN/wasm-opt" -g --strip --strip-producers -c -O3 \
sqlite3.tmp -o sqlite3.wasm \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

View File

@@ -59,14 +59,13 @@ sqlite3_db_filename
sqlite3_db_name
sqlite3_db_readonly
sqlite3_db_release_memory
sqlite3_db_status64
sqlite3_db_status
sqlite3_declare_vtab
sqlite3_errcode
sqlite3_errmsg
sqlite3_error_offset
sqlite3_errstr
sqlite3_exec
sqlite3_exec_go
sqlite3_expanded_sql
sqlite3_file_control
sqlite3_filename_database
@@ -78,10 +77,8 @@ sqlite3_get_autocommit
sqlite3_get_auxdata
sqlite3_hard_heap_limit64
sqlite3_interrupt
sqlite3_invoke_busy_handler_go
sqlite3_last_insert_rowid
sqlite3_limit
sqlite3_log_go
sqlite3_malloc64
sqlite3_open_v2
sqlite3_overload_function
@@ -98,7 +95,6 @@ sqlite3_result_error_toobig
sqlite3_result_int64
sqlite3_result_null
sqlite3_result_pointer_go
sqlite3_result_subtype
sqlite3_result_text_go
sqlite3_result_value
sqlite3_result_zeroblob64
@@ -127,7 +123,6 @@ sqlite3_value_int64
sqlite3_value_nochange
sqlite3_value_numeric_type
sqlite3_value_pointer_go
sqlite3_value_subtype
sqlite3_value_text
sqlite3_value_type
sqlite3_vtab_collation

View File

@@ -8,16 +8,13 @@ package embed
import (
_ "embed"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
//go:embed sqlite3.wasm
var binary string
var binary []byte
func init() {
if sqlite3.Binary == nil {
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
}
sqlite3.Binary = binary
}

View File

@@ -19,7 +19,7 @@ func Test_init(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if version != "3.51.1" {
if version != "3.47.2" {
t.Error(version)
}
}

Binary file not shown.

View File

@@ -2,6 +2,7 @@ package sqlite3
import (
"errors"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -11,10 +12,10 @@ import (
//
// https://sqlite.org/c3ref/errcode.html
type Error struct {
sys error
str string
msg string
sql string
code res_t
code uint64
}
// Code returns the primary error code for this error.
@@ -28,34 +29,28 @@ func (e *Error) Code() ErrorCode {
//
// https://sqlite.org/rescode.html
func (e *Error) ExtendedCode() ExtendedErrorCode {
return xErrorCode(e.code)
return ExtendedErrorCode(e.code)
}
// Error implements the error interface.
func (e *Error) Error() string {
var b strings.Builder
b.WriteString(util.ErrorCodeString(e.code))
b.WriteString("sqlite3: ")
if e.str != "" {
b.WriteString(e.str)
} else {
b.WriteString(strconv.Itoa(int(e.code)))
}
if e.msg != "" {
b.WriteString(": ")
b.WriteString(e.msg)
}
if e.sys != nil {
b.WriteString(": ")
b.WriteString(e.sys.Error())
}
return b.String()
}
// Unwrap returns the underlying operating system error
// that caused the I/O error or failure to open a file.
//
// https://sqlite.org/c3ref/system_errno.html
func (e *Error) Unwrap() error {
return e.sys
}
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
//
// It makes it possible to do:
@@ -88,7 +83,7 @@ func (e *Error) As(err any) bool {
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY || e.Code() == INTERRUPT
return e.Code() == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
@@ -103,31 +98,22 @@ func (e *Error) SQL() string {
// Error implements the error interface.
func (e ErrorCode) Error() string {
return util.ErrorCodeString(e)
}
// As converts this error to an [ExtendedErrorCode].
func (e ErrorCode) As(err any) bool {
c, ok := err.(*xErrorCode)
if ok {
*c = xErrorCode(e)
}
return ok
return util.ErrorCodeString(uint32(e))
}
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY || e == INTERRUPT
return e == BUSY
}
// ExtendedCode returns the extended error code for this error.
func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
return xErrorCode(e)
return ExtendedErrorCode(e)
}
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
return util.ErrorCodeString(e)
return util.ErrorCodeString(uint32(e))
}
// Is tests whether this error matches a given [ErrorCode].
@@ -147,7 +133,7 @@ func (e ExtendedErrorCode) As(err any) bool {
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
return ErrorCode(e) == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
@@ -160,23 +146,27 @@ func (e ExtendedErrorCode) Code() ErrorCode {
return ErrorCode(e)
}
func errorCode(err error, def ErrorCode) (msg string, code res_t) {
func errorCode(err error, def ErrorCode) (msg string, code uint32) {
switch code := err.(type) {
case nil:
return "", _OK
case ErrorCode:
return "", res_t(code)
return "", uint32(code)
case xErrorCode:
return "", res_t(code)
return "", uint32(code)
case *Error:
return code.msg, res_t(code.code)
return code.msg, uint32(code.code)
}
var ecode ErrorCode
var xcode xErrorCode
if errors.As(err, &xcode) {
code = res_t(xcode)
} else {
code = res_t(def)
switch {
case errors.As(err, &xcode):
code = uint32(xcode)
case errors.As(err, &ecode):
code = uint32(ecode)
default:
code = uint32(def)
}
return err.Error(), code
}

View File

@@ -43,7 +43,7 @@ func TestError(t *testing.T) {
if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: unknown error" {
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
@@ -59,14 +59,14 @@ func TestError_Temporary(t *testing.T) {
tests := []struct {
name string
code res_t
code uint64
want bool
}{
{"ERROR", res_t(ERROR), false},
{"BUSY", res_t(BUSY), true},
{"BUSY_RECOVERY", res_t(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true},
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), true},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -83,7 +83,7 @@ func TestError_Temporary(t *testing.T) {
}
}
{
err := xErrorCode(tt.code)
err := ExtendedErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
@@ -97,14 +97,14 @@ func TestError_Timeout(t *testing.T) {
tests := []struct {
name string
code res_t
code uint64
want bool
}{
{"ERROR", res_t(ERROR), false},
{"BUSY", res_t(BUSY), false},
{"BUSY_RECOVERY", res_t(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true},
{"ERROR", uint64(ERROR), false},
{"BUSY", uint64(BUSY), false},
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -115,7 +115,7 @@ func TestError_Timeout(t *testing.T) {
}
}
{
err := xErrorCode(tt.code)
err := ExtendedErrorCode(tt.code)
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
@@ -136,8 +136,8 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME)
r := db.call("sqlite3_errstr", uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
got := ErrorCode(i).Error()
if got != want {
@@ -156,12 +156,12 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
defer db.Close()
// Test all extended error codes.
for i := 0; i == int(xErrorCode(i)); i++ {
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME)
r := db.call("sqlite3_errstr", uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
got := xErrorCode(i).Error()
got := ExtendedErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}
@@ -172,7 +172,7 @@ func Test_errorCode(t *testing.T) {
tests := []struct {
arg error
wantMsg string
wantCode res_t
wantCode uint32
}{
{nil, "", _OK},
{ERROR, "", util.ERROR},
@@ -190,7 +190,7 @@ func Test_errorCode(t *testing.T) {
if gotMsg != tt.wantMsg {
t.Errorf("errorCode() gotMsg = %q, want %q", gotMsg, tt.wantMsg)
}
if gotCode != tt.wantCode {
if gotCode != uint32(tt.wantCode) {
t.Errorf("errorCode() gotCode = %d, want %d", gotCode, tt.wantCode)
}
})

View File

@@ -25,24 +25,13 @@ you can load into your database connections.
creates [pivot tables](https://github.com/jakethaw/pivot_vtab).
- [`github.com/ncruces/go-sqlite3/ext/regexp`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/regexp)
provides regular expression functions.
- [`github.com/ncruces/go-sqlite3/ext/serdes`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/serdes)
(de)serializes databases.
- [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement)
creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab).
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
provides [statistics](https://oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html) functions.
provides [statistics](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html) functions.
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
provides [Unicode aware](https://sqlite.org/src/dir/ext/icu) functions.
- [`github.com/ncruces/go-sqlite3/ext/uuid`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/uuid)
generates [UUIDs](https://en.wikipedia.org/wiki/Universally_unique_identifier).
- [`github.com/ncruces/go-sqlite3/ext/zorder`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/zorder)
maps multidimensional data to one dimension.
### Packages
These packages may also be useful to work with SQLite:
- [`github.com/ncruces/decimal`](https://pkg.go.dev/github.com/ncruces/decimal)
decimal arithmetic.
- [`github.com/ncruces/julianday`](https://pkg.go.dev/github.com/ncruces/julianday)
Julian day math.
maps multidimensional data to one dimension.

View File

@@ -59,8 +59,7 @@ func (c *cursor) Next() error {
}
func (c *cursor) RowID() (int64, error) {
// One-based RowID for consistency with carray and other tables.
return int64(c.rowID) + 1, nil
return int64(c.rowID), nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {

View File

@@ -88,9 +88,9 @@ func Example() {
func Test_cursor_Column(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, array.Register)
db, err := driver.Open(tmp, array.Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,11 +1,10 @@
package blobio_test
import (
"database/sql"
"io"
"log"
"os"
"slices"
"reflect"
"strings"
"testing"
@@ -35,8 +34,7 @@ func Example() {
const message = "Hello BLOB!"
// Create the BLOB.
r, err := db.Exec(`INSERT INTO test VALUES (:data)`,
sql.Named("data", sqlite3.ZeroBlob(len(message))))
r, err := db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}
@@ -47,19 +45,15 @@ func Example() {
}
// Write the BLOB.
_, err = db.Exec(`SELECT writeblob('main', 'test', 'col', :rowid, :offset, :message)`,
sql.Named("rowid", id),
sql.Named("offset", 0),
sql.Named("message", message))
_, err = db.Exec(`SELECT writeblob('main', 'test', 'col', ?, 0, ?)`,
id, message)
if err != nil {
log.Fatal(err)
}
// Read the BLOB.
_, err = db.Exec(`SELECT readblob('main', 'test', 'col', :rowid, :offset, :writer)`,
sql.Named("rowid", id),
sql.Named("offset", 0),
sql.Named("writer", sqlite3.Pointer(os.Stdout)))
_, err = db.Exec(`SELECT readblob('main', 'test', 'col', ?, 0, ?)`,
id, sqlite3.Pointer(os.Stdout))
if err != nil {
log.Fatal(err)
}
@@ -70,7 +64,7 @@ func Example() {
func TestMain(m *testing.M) {
sqlite3.AutoExtension(blobio.Register)
sqlite3.AutoExtension(array.Register)
os.Exit(m.Run())
m.Run()
}
func Test_readblob(t *testing.T) {
@@ -144,16 +138,18 @@ func Test_readblob(t *testing.T) {
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != tt.want1 {
t.Errorf("got %q", got)
if stmt.Step() {
got := stmt.ColumnText(0)
if got != tt.want1 {
t.Errorf("got %q", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != tt.want2 {
t.Errorf("got %q", got)
if stmt.Step() {
got := stmt.ColumnText(0)
if got != tt.want2 {
t.Errorf("got %q", got)
}
}
err = stmt.Err()
@@ -282,7 +278,7 @@ func Test_openblob(t *testing.T) {
}
want := []string{"\xca\xfe", "\xba\xbe"}
if !slices.Equal(got, want) {
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
// Register registers the bloom_filter virtual table:
@@ -35,8 +34,6 @@ type bloom struct {
hashes int
}
const vtab = `CREATE TABLE x(present, word TEXT HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`
func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err error) {
b := bloom{
db: db,
@@ -58,9 +55,11 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
}
if len(arg) > 1 {
var ok bool
b.prob, ok = sql3util.ParseFloat(arg[1])
if !ok || b.prob <= 0 || b.prob >= 1 {
b.prob, err = strconv.ParseFloat(arg[1], 64)
if err != nil {
return nil, err
}
if b.prob <= 0 || b.prob >= 1 {
return nil, util.ErrorString("bloom: probability must be in the range (0,1)")
}
} else {
@@ -81,7 +80,8 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
b.bytes = numBytes(nelem, b.prob)
err = db.DeclareVTab(vtab)
err = db.DeclareVTab(
`CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`)
if err != nil {
return nil, err
}
@@ -115,15 +115,15 @@ func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom
storage: table + "_storage",
}
err = db.DeclareVTab(vtab)
err = db.DeclareVTab(
`CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`)
if err != nil {
return nil, err
}
load, _, err := db.PrepareFlags(fmt.Sprintf(
load, _, err := db.Prepare(fmt.Sprintf(
`SELECT m/8, p, k FROM %s.%s WHERE rowid = 1`,
sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage)),
sqlite3.PREPARE_DONT_LOG)
sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage)))
if err != nil {
return nil, err
}
@@ -166,10 +166,9 @@ func (t *bloom) ShadowTables() {
}
func (t *bloom) Integrity(schema, table string, flags int) error {
load, _, err := t.db.PrepareFlags(fmt.Sprintf(
load, _, err := t.db.Prepare(fmt.Sprintf(
`SELECT typeof(data), length(data), p, n, m, k FROM %s.%s WHERE rowid = 1`,
sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage)),
sqlite3.PREPARE_DONT_LOG)
sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage)))
if err != nil {
return fmt.Errorf("bloom: %v", err) // can't wrap!
}
@@ -233,7 +232,7 @@ func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) {
}
defer f.Close()
for n := range b.hashes {
for n := 0; n < b.hashes; n++ {
hash := calcHash(n, blob)
hash %= uint64(b.bytes * 8)
bitpos := byte(hash % 8)
@@ -269,13 +268,13 @@ func (b *bloom) Open() (sqlite3.VTabCursor, error) {
type cursor struct {
*bloom
arg sqlite3.Value
arg *sqlite3.Value
eof bool
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.eof = false
c.arg = arg[0]
c.arg = &arg[0]
blob := arg[0].RawBlob()
f, err := c.db.OpenBlob(c.schema, c.storage, "data", 1, false)
@@ -313,7 +312,7 @@ func (c *cursor) Column(ctx sqlite3.Context, n int) error {
case 0:
ctx.ResultBool(true)
case 1:
ctx.ResultValue(c.arg)
ctx.ResultValue(*c.arg)
}
return nil
}

View File

@@ -14,7 +14,7 @@ import (
func TestMain(m *testing.M) {
sqlite3.AutoExtension(bloom.Register)
os.Exit(m.Run())
m.Run()
}
func TestRegister(t *testing.T) {

View File

@@ -56,7 +56,7 @@ func Register(db *sqlite3.Conn) error {
done.Add(key)
}
err := db.DeclareVTab(`CREATE TABLE x(id INT,depth INT,root HIDDEN,tablename TEXT HIDDEN,idcolumn TEXT HIDDEN,parentcolumn TEXT HIDDEN)`)
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
if err != nil {
return nil, err
}
@@ -154,7 +154,6 @@ func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
return sqlite3.CONSTRAINT
}
idx.IdxFlags = sqlite3.INDEX_SCAN_HEX
idx.EstimatedCost = cost
idx.IdxNum = plan
return nil
@@ -202,7 +201,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
sqlite3.QuoteIdentifier(column),
sqlite3.QuoteIdentifier(parent),
)
stmt, _, err := c.db.PrepareFlags(sql, sqlite3.PREPARE_DONT_LOG)
stmt, _, err := c.db.Prepare(sql)
if err != nil {
return err
}
@@ -211,14 +210,12 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.nodes = []node{{root, 0}}
set := util.Set[int64]{}
set.Add(root)
for i := range c.nodes {
for i := 0; i < len(c.nodes); i++ {
curr := c.nodes[i]
if curr.depth >= maxDepth {
continue
}
if err := stmt.BindInt64(1, curr.id); err != nil {
return err
}
stmt.BindInt64(1, curr.id)
for stmt.Step() {
if stmt.ColumnType(0) == sqlite3.INTEGER {
next := stmt.ColumnInt64(0)
@@ -228,9 +225,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
}
}
}
if err := stmt.Reset(); err != nil {
return err
}
stmt.Reset()
}
return nil
}

View File

@@ -4,7 +4,6 @@ import (
_ "embed"
"fmt"
"log"
"os"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -15,7 +14,7 @@ import (
func TestMain(m *testing.M) {
sqlite3.AutoExtension(closure.Register)
os.Exit(m.Run())
m.Run()
}
func Example() {

View File

@@ -30,7 +30,7 @@ func Register(db *sqlite3.Conn) error {
// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) {
var (
filename string
data string
@@ -100,7 +100,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
}
schema = getSchema(header, columns, row)
} else {
t.typs, err = getColumnAffinities(db, schema)
t.typs, err = getColumnAffinities(schema)
if err != nil {
return nil, err
}
@@ -214,10 +214,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return err
}
if c.table.header {
err = c.Next() // skip header
if err != nil {
return err
}
c.Next() // skip header
}
c.rowID = 0
return c.Next()
@@ -254,15 +251,19 @@ func (c *cursor) Column(ctx sqlite3.Context, col int) error {
switch typ {
case numeric, integer:
if i, err := strconv.ParseInt(txt, 10, 64); err == nil {
ctx.ResultInt64(i)
return nil
if strings.TrimLeft(txt, "+-0123456789") == "" {
if i, err := strconv.ParseInt(txt, 10, 64); err == nil {
ctx.ResultInt64(i)
return nil
}
}
fallthrough
case real:
if f, ok := sql3util.ParseFloat(txt); ok {
ctx.ResultFloat(f)
return nil
if strings.TrimLeft(txt, "+-.0123456789Ee") == "" {
if f, err := strconv.ParseFloat(txt, 64); err == nil {
ctx.ResultFloat(f)
return nil
}
}
fallthrough
default:

View File

@@ -3,7 +3,6 @@ package csv_test
import (
"fmt"
"log"
"os"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -57,7 +56,7 @@ func Example() {
func TestMain(m *testing.M) {
sqlite3.AutoExtension(csv.Register)
os.Exit(m.Run())
m.Run()
}
func TestRegister(t *testing.T) {
@@ -147,21 +146,20 @@ func TestAffinity(t *testing.T) {
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != "1" {
t.Errorf("got %q want 1", got)
if stmt.Step() {
if got := stmt.ColumnText(0); got != "1" {
t.Errorf("got %q want 1", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != "0.1" {
t.Errorf("got %q want 0.1", got)
if stmt.Step() {
if got := stmt.ColumnText(0); got != "0.1" {
t.Errorf("got %q want 0.1", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != "e" {
t.Errorf("got %q want e", got)
if stmt.Step() {
if got := stmt.ColumnText(0); got != "e" {
t.Errorf("got %q want e", got)
}
}
}

View File

@@ -3,8 +3,6 @@ package csv
import (
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
@@ -18,17 +16,7 @@ const (
real affinity = 4
)
func getColumnAffinities(db *sqlite3.Conn, schema string) ([]affinity, error) {
stmt, tail, err := db.PrepareFlags(schema,
sqlite3.PREPARE_DONT_LOG|sqlite3.PREPARE_NO_VTAB|sqlite3.PREPARE_FROM_DDL)
if err != nil {
return nil, err
}
stmt.Close()
if tail != "" {
return nil, util.TailErr
}
func getColumnAffinities(schema string) ([]affinity, error) {
tab, err := sql3util.ParseTable(schema)
if err != nil {
return nil, err

70
ext/fileio/coro.go Normal file
View File

@@ -0,0 +1,70 @@
//go:build !go1.23
package fileio
import (
"fmt"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Adapted from: https://research.swtch.com/coro
const errCoroCanceled = util.ErrorString("coroutine canceled")
func coroNew[In, Out any](f func(In, func(Out) In) Out) (resume func(In) (Out, bool), cancel func()) {
type msg[T any] struct {
panic any
val T
}
cin := make(chan msg[In])
cout := make(chan msg[Out])
running := true
resume = func(in In) (out Out, ok bool) {
if !running {
return
}
cin <- msg[In]{val: in}
m := <-cout
if m.panic != nil {
panic(m.panic)
}
return m.val, running
}
cancel = func() {
if !running {
return
}
e := fmt.Errorf("%w", errCoroCanceled)
cin <- msg[In]{panic: e}
m := <-cout
if m.panic != nil && m.panic != e {
panic(m.panic)
}
}
yield := func(out Out) In {
cout <- msg[Out]{val: out}
m := <-cin
if m.panic != nil {
panic(m.panic)
}
return m.val
}
go func() {
defer func() {
if running {
running = false
cout <- msg[Out]{panic: recover()}
}
}()
var out Out
m := <-cin
if m.panic == nil {
out = f(m.val, yield)
}
running = false
cout <- msg[Out]{val: out}
}()
return resume, cancel
}

View File

@@ -18,7 +18,7 @@ func Register(db *sqlite3.Conn) error {
return RegisterFS(db, nil)
}
// RegisterFS registers SQL functions readfile, lsmode,
// Register registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
@@ -30,7 +30,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name TEXT,mode INT,mtime TIMESTAMP,data BLOB,path HIDDEN,dir HIDDEN)`)
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
if err == nil {
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
}
@@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
}
func readfile(fsys fs.FS) sqlite3.ScalarFunction {
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) {
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
var err error
var data []byte

View File

@@ -17,9 +17,9 @@ import (
func Test_lsmode(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, fileio.Register)
db, err := driver.Open(tmp, fileio.Register)
if err != nil {
t.Fatal(err)
}
@@ -53,9 +53,9 @@ func Test_readfile(t *testing.T) {
for _, fsys := range []fs.FS{nil, os.DirFS(".")} {
t.Run("", func(t *testing.T) {
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, func(c *sqlite3.Conn) error {
db, err := driver.Open(tmp, func(c *sqlite3.Conn) error {
fileio.RegisterFS(c, fsys)
return nil
})

View File

@@ -2,7 +2,6 @@ package fileio
import (
"io/fs"
"iter"
"os"
"path"
"path/filepath"
@@ -63,12 +62,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
type cursor struct {
fsdir
base string
next func() (entry, bool)
stop func()
curr entry
eof bool
rowID int64
base string
resume resume
cancel func()
curr entry
eof bool
rowID int64
}
type entry struct {
@@ -78,8 +77,8 @@ type entry struct {
}
func (c *cursor) Close() error {
if c.stop != nil {
c.stop()
if c.cancel != nil {
c.cancel()
}
return nil
}
@@ -102,26 +101,14 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.base = base
}
c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
walkDir := func(path string, d fs.DirEntry, err error) error {
if yield(entry{d, err, path}) {
return nil
}
return fs.SkipAll
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
})
c.resume, c.cancel = pull(c, root)
c.eof = false
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
curr, ok := c.next()
curr, ok := next(c)
c.curr = curr
c.eof = !ok
c.rowID++

29
ext/fileio/fsdir_coro.go Normal file
View File

@@ -0,0 +1,29 @@
//go:build !go1.23
package fileio
import (
"io/fs"
"path/filepath"
)
type resume = func(struct{}) (entry, bool)
func next(c *cursor) (entry, bool) {
return c.resume(struct{}{})
}
func pull(c *cursor, root string) (resume, func()) {
return coroNew(func(_ struct{}, yield func(entry) struct{}) entry {
walkDir := func(path string, d fs.DirEntry, err error) error {
yield(entry{d, err, path})
return nil
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
return entry{}
})
}

31
ext/fileio/fsdir_iter.go Normal file
View File

@@ -0,0 +1,31 @@
//go:build go1.23
package fileio
import (
"io/fs"
"iter"
"path/filepath"
)
type resume = func() (entry, bool)
func next(c *cursor) (entry, bool) {
return c.resume()
}
func pull(c *cursor, root string) (resume, func()) {
return iter.Pull(func(yield func(entry) bool) {
walkDir := func(path string, d fs.DirEntry, err error) error {
if yield(entry{d, err, path}) {
return nil
}
return fs.SkipAll
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
})
}

View File

@@ -21,9 +21,9 @@ func Test_fsdir(t *testing.T) {
for _, fsys := range []fs.FS{nil, os.DirFS(".")} {
t.Run("", func(t *testing.T) {
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, func(c *sqlite3.Conn) error {
db, err := driver.Open(tmp, func(c *sqlite3.Conn) error {
fileio.RegisterFS(c, fsys)
return nil
})

View File

@@ -15,9 +15,9 @@ import (
func Test_writefile(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, Register)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -4,7 +4,6 @@ import (
_ "crypto/md5"
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha3"
_ "crypto/sha512"
"testing"
@@ -12,6 +11,7 @@ import (
_ "golang.org/x/crypto/blake2s"
_ "golang.org/x/crypto/md4"
_ "golang.org/x/crypto/ripemd160"
_ "golang.org/x/crypto/sha3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
@@ -21,7 +21,7 @@ import (
func TestRegister(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
tests := []struct {
name string
@@ -55,7 +55,7 @@ func TestRegister(t *testing.T) {
{"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"},
}
db, err := driver.Open(dsn, Register)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,113 +0,0 @@
// Package ipaddr provides functions to manipulate IPs and CIDRs.
//
// It provides the following functions:
// - ipcontains(prefix, ip)
// - ipoverlaps(prefix1, prefix2)
// - ipfamily(ip/prefix)
// - iphost(ip/prefix)
// - ipmasklen(prefix)
// - ipnetwork(prefix)
package ipaddr
import (
"errors"
"net/netip"
"github.com/ncruces/go-sqlite3"
)
// Register IP/CIDR functions for a database connection.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("ipcontains", 2, flags, contains),
db.CreateFunction("ipoverlaps", 2, flags, overlaps),
db.CreateFunction("ipfamily", 1, flags, family),
db.CreateFunction("iphost", 1, flags, host),
db.CreateFunction("ipmasklen", 1, flags, masklen),
db.CreateFunction("ipnetwork", 1, flags, network))
}
func contains(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
addr, err := netip.ParseAddr(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultBool(prefix.Contains(addr))
}
func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix1, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
prefix2, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultBool(prefix1.Overlaps(prefix2))
}
func family(ctx sqlite3.Context, arg ...sqlite3.Value) {
addr, err := addr(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
switch {
case addr.Is4():
ctx.ResultInt(4)
case addr.Is6():
ctx.ResultInt(6)
}
}
func host(ctx sqlite3.Context, arg ...sqlite3.Value) {
addr, err := addr(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
buf, _ := addr.MarshalText()
ctx.ResultRawText(buf)
}
func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultInt(prefix.Bits())
}
func network(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
buf, _ := prefix.Masked().MarshalText()
ctx.ResultRawText(buf)
}
func addr(text string) (netip.Addr, error) {
addr, err := netip.ParseAddr(text)
if err != nil {
if prefix, err := netip.ParsePrefix(text); err == nil {
return prefix.Addr(), nil
}
if addrpt, err := netip.ParseAddrPort(text); err == nil {
return addrpt.Addr(), nil
}
}
return addr, err
}

View File

@@ -1,88 +0,0 @@
package ipaddr_test
import (
"testing"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/ipaddr"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
"github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestRegister(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
db, err := driver.Open(dsn, ipaddr.Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var got string
err = db.QueryRow(`SELECT ipfamily('::1')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "6" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipfamily('[::1]:80')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "6" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipfamily('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "4" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT iphost('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "192.168.1.5" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipmasklen('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "24" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipnetwork('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "192.168.1.0/24" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipcontains('192.168.1.0/24', '192.168.1.5')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "1" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipoverlaps('192.168.1.0/24', '192.168.1.5/32')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "1" {
t.Fatalf("got %s", got)
}
}

View File

@@ -67,9 +67,9 @@ func Example() {
func Test_lines(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, lines.Register)
db, err := driver.Open(tmp, lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -98,9 +98,9 @@ func Test_lines(t *testing.T) {
func Test_lines_error(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, lines.Register)
db, err := driver.Open(tmp, lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -123,9 +123,9 @@ func Test_lines_error(t *testing.T) {
func Test_lines_read(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, lines.Register)
db, err := driver.Open(tmp, lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -155,9 +155,9 @@ func Test_lines_read(t *testing.T) {
func Test_lines_test(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, lines.Register)
db, err := driver.Open(tmp, lines.Register)
if err != nil {
log.Fatal(err)
}

23
ext/parquet/go.mod Normal file
View File

@@ -0,0 +1,23 @@
module github.com/ncruces/go-sqlite3/ext/parquet
go 1.22
toolchain go1.23.0
require (
github.com/ncruces/go-sqlite3 v0.21.0
github.com/parquet-go/parquet-go v0.24.0
)
require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
golang.org/x/sys v0.28.0 // indirect
)

32
ext/parquet/go.sum Normal file
View File

@@ -0,0 +1,32 @@
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/ncruces/go-sqlite3 v0.21.0 h1:EwKFoy1hHEopN4sFZarmi+McXdbCcbTuLixhEayXVbQ=
github.com/ncruces/go-sqlite3 v0.21.0/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/parquet-go/parquet-go v0.24.0 h1:VrsifmLPDnas8zpoHmYiWDZ1YHzLmc7NmNwPGkI2JM4=
github.com/parquet-go/parquet-go v0.24.0/go.mod h1:OqBBRGBl7+llplCvDMql8dEKaDqjaFA/VAPw+OJiNiw=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=

62
ext/parquet/parquet.go Normal file
View File

@@ -0,0 +1,62 @@
package parquet
import (
"os"
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/ncruces/go-sqlite3/util/sql3util"
"github.com/parquet-go/parquet-go"
)
func Register(db *sqlite3.Conn) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
if len(arg) == 0 {
return nil, util.ErrorString(`parquet: must specify a filename`)
}
file, err := osutil.OpenFile(sql3util.Unquote(arg[0]), os.O_RDONLY, 0)
if err != nil {
return nil, err
}
reader := parquet.NewReader(file)
column := make(map[int]string)
var schema strings.Builder
schema.WriteString("CREATE TABLE x(")
for i, field := range reader.Schema().Fields() {
if i > 0 {
schema.WriteByte(',')
}
schema.WriteString(sqlite3.QuoteIdentifier(field.Name()))
schema.WriteByte(' ')
switch field.Type().Kind() {
case parquet.Boolean:
schema.WriteString("BOOLEAN")
case parquet.Int32, parquet.Int64, parquet.Int96:
schema.WriteString("INTEGER")
case parquet.Float, parquet.Double:
schema.WriteString("REAL")
case parquet.ByteArray, parquet.FixedLenByteArray:
schema.WriteString("TEXT")
}
// Save the column name
column[i] = field.Name()
}
schema.WriteString(");")
err = db.DeclareVTab(schema.String())
if err != nil {
return nil, err
}
return &table{}, nil
}
return sqlite3.CreateModule(db, "parquet", declare, declare)
}
type table struct {
}

View File

@@ -25,14 +25,14 @@ type table struct {
cols []*sqlite3.Value
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err error) {
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) {
if len(arg) != 3 {
return nil, fmt.Errorf("pivot: wrong number of arguments")
}
t := &table{db: db}
defer func() {
if ret == nil {
if res == nil {
t.Close()
}
}()
@@ -43,14 +43,11 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err e
// Row key query.
t.scan = "SELECT * FROM\n" + arg[0]
stmt, tail, err := db.PrepareFlags(t.scan, sqlite3.PREPARE_FROM_DDL)
stmt, _, err := db.Prepare(t.scan)
if err != nil {
return nil, err
}
defer stmt.Close()
if tail != "" {
return nil, util.TailErr
}
t.keys = make([]string, stmt.ColumnCount())
for i := range t.keys {
@@ -58,20 +55,15 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err e
t.keys[i] = name
create.WriteString(sep)
create.WriteString(name)
create.WriteString(" ")
create.WriteString(stmt.ColumnDeclType(i))
sep = ","
}
stmt.Close()
// Column definition query.
stmt, tail, err = db.PrepareFlags("SELECT * FROM\n"+arg[1], sqlite3.PREPARE_FROM_DDL)
stmt, _, err = db.Prepare("SELECT * FROM\n" + arg[1])
if err != nil {
return nil, err
}
if tail != "" {
return nil, util.TailErr
}
if stmt.ColumnCount() != 2 {
return nil, util.ErrorString("pivot: column definition query expects 2 result columns")
@@ -79,23 +71,17 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err e
for stmt.Step() {
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
t.cols = append(t.cols, stmt.ColumnValue(0).Dup())
create.WriteString(sep)
create.WriteString(",")
create.WriteString(name)
create.WriteString(" ")
create.WriteString(stmt.ColumnDeclType(1))
sep = ","
}
stmt.Close()
// Pivot cell query.
t.cell = "SELECT * FROM\n" + arg[2]
stmt, tail, err = db.PrepareFlags(t.cell, sqlite3.PREPARE_FROM_DDL)
stmt, _, err = db.Prepare(t.cell)
if err != nil {
return nil, err
}
if tail != "" {
return nil, util.TailErr
}
if stmt.ColumnCount() != 1 {
return nil, util.ErrorString("pivot: cell query expects 1 result columns")
@@ -113,11 +99,10 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err e
}
func (t *table) Close() error {
var errs []error
for _, c := range t.cols {
errs = append(errs, c.Close())
c.Close()
}
return errors.Join(errs...)
return nil
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
@@ -196,9 +181,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return err
}
const prepflags = sqlite3.PREPARE_DONT_LOG | sqlite3.PREPARE_FROM_DDL
c.scan, _, err = c.table.db.PrepareFlags(idxStr, prepflags)
c.scan, _, err = c.table.db.Prepare(idxStr)
if err != nil {
return err
}
@@ -210,7 +193,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
}
if c.cell == nil {
c.cell, _, err = c.table.db.PrepareFlags(c.table.cell, prepflags)
c.cell, _, err = c.table.db.Prepare(c.table.cell)
if err != nil {
return err
}
@@ -223,7 +206,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
func (c *cursor) Next() error {
if c.scan.Step() {
count := c.scan.ColumnCount()
for i := range count {
for i := 0; i < count; i++ {
err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
if err != nil {
return err

View File

@@ -3,7 +3,6 @@ package pivot_test
import (
"fmt"
"log"
"os"
"strings"
"testing"
@@ -86,7 +85,7 @@ func Example() {
func TestMain(m *testing.M) {
sqlite3.AutoExtension(pivot.Register)
os.Exit(m.Run())
m.Run()
}
func TestRegister(t *testing.T) {
@@ -141,10 +140,10 @@ func TestRegister(t *testing.T) {
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %d, want 3", got)
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %d, want 3", got)
}
}
err = db.Exec(`ALTER TABLE v_x RENAME TO v_y`)

View File

@@ -16,9 +16,7 @@ package regexp
import (
"errors"
"regexp"
"regexp/syntax"
"strings"
"unicode/utf8"
"github.com/ncruces/go-sqlite3"
)
@@ -52,83 +50,34 @@ func Register(db *sqlite3.Conn) error {
// SELECT column WHERE column GLOB :glob_prefix AND column REGEXP :regexp
//
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
func GlobPrefix(expr string) string {
re, err := syntax.Parse(expr, syntax.Perl)
if err != nil {
return "" // no match possible
}
prog, err := syntax.Compile(re.Simplify())
if err != nil {
return "" // notest
}
i := &prog.Inst[prog.Start]
var empty syntax.EmptyOp
loop1:
for {
switch i.Op {
case syntax.InstFail:
return "" // notest
case syntax.InstCapture, syntax.InstNop:
// skip
case syntax.InstEmptyWidth:
empty |= syntax.EmptyOp(i.Arg)
default:
break loop1
func GlobPrefix(re *regexp.Regexp) string {
prefix, complete := re.LiteralPrefix()
i := strings.IndexAny(prefix, "*?[")
if i < 0 {
if complete {
return prefix
}
i = &prog.Inst[i.Out]
i = len(prefix)
}
if empty&syntax.EmptyBeginText == 0 {
return "*" // not anchored
}
var glob strings.Builder
loop2:
for {
switch i.Op {
case syntax.InstFail:
return "" // notest
case syntax.InstCapture, syntax.InstEmptyWidth, syntax.InstNop:
// skip
case syntax.InstRune, syntax.InstRune1:
if len(i.Rune) != 1 || syntax.Flags(i.Arg)&syntax.FoldCase != 0 {
break loop2
}
switch r := i.Rune[0]; r {
case '*', '?', '[', utf8.RuneError:
break loop2
default:
glob.WriteRune(r)
}
default:
break loop2
}
i = &prog.Inst[i.Out]
}
glob.WriteByte('*')
return glob.String()
return prefix[:i] + "*"
}
func load(ctx sqlite3.Context, arg []sqlite3.Value, i int) (*regexp.Regexp, error) {
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
if !ok {
re, ok = arg[i].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[i].Text())
if err != nil {
return nil, err
}
re = r
r, err := regexp.Compile(expr)
if err != nil {
return nil, err
}
ctx.SetAuxData(i, re)
re = r
ctx.SetAuxData(0, r)
}
return re, nil
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 0)
_ = arg[1] // bounds check
re, err := load(ctx, 0, arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
@@ -138,17 +87,18 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
ctx.ResultBool(re.Match(text))
}
func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
@@ -163,7 +113,7 @@ func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
@@ -188,7 +138,7 @@ func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
@@ -216,14 +166,16 @@ func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
_ = arg[2] // bounds check
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
repl := arg[2].RawText()
text := arg[0].RawText()
repl := arg[2].RawText()
var pos, n int
if len(arg) > 3 {
pos = arg[3].Int()

View File

@@ -3,10 +3,8 @@ package regexp
import (
"database/sql"
"regexp"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -15,9 +13,9 @@ import (
func TestRegister(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, Register)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}
@@ -38,7 +36,7 @@ func TestRegister(t *testing.T) {
{`regexp_instr('Hello', '.', 6)`, ""},
{`regexp_substr('Hello', 'el.')`, "ell"},
{`regexp_replace('Hello', 'llo', 'll')`, "Hell"},
// https://postgresql.org/docs/current/functions-matching.html
// https://www.postgresql.org/docs/current/functions-matching.html
{`regexp_count('ABCABCAXYaxy', 'A.')`, "3"},
{`regexp_count('ABCABCAXYaxy', '(?i)A.', 1)`, "4"},
{`regexp_instr('number of your street, town zip, FR', '[^,]+', 1, 2)`, "23"},
@@ -80,9 +78,9 @@ func TestRegister(t *testing.T) {
func TestRegister_errors(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, Register)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}
@@ -105,81 +103,24 @@ func TestRegister_errors(t *testing.T) {
}
}
func TestRegister_pointer(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
db, err := driver.Open(dsn, Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var got int
err = db.QueryRow(`SELECT regexp_count('ABCABCAXYaxy', ?, 1)`,
sqlite3.Pointer(regexp.MustCompile(`(?i)A.`))).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != 4 {
t.Errorf("got %d, want %d", got, 4)
}
}
func TestGlobPrefix(t *testing.T) {
tests := []struct {
re string
want string
}{
{`[`, ""},
{``, "*"},
{`^`, "*"},
{`a`, "*"},
{`ab`, "*"},
{`^a`, "a*"},
{`^a*`, "*"},
{`^a+`, "a*"},
{`^ab*`, "a*"},
{`^ab+`, "ab*"},
{`^a\?b`, "a*"},
{`^[a-z]`, "*"},
{``, ""},
{`a`, "a"},
{`a*`, "*"},
{`a+`, "a*"},
{`ab*`, "a*"},
{`ab+`, "ab*"},
{`a\?b`, "a*"},
}
for _, tt := range tests {
t.Run(tt.re, func(t *testing.T) {
if got := GlobPrefix(tt.re); got != tt.want {
t.Errorf("GlobPrefix(%v) = %v, want %v", tt.re, got, tt.want)
if got := GlobPrefix(regexp.MustCompile(tt.re)); got != tt.want {
t.Errorf("GlobPrefix() = %v, want %v", got, tt.want)
}
})
}
}
func FuzzGlobPrefix(f *testing.F) {
f.Add(``, ``)
f.Add(`[`, ``)
f.Add(`^`, ``)
f.Add(`a`, `a`)
f.Add(`ab`, `b`)
f.Add(`^a`, `a`)
f.Add(`^a*`, `ab`)
f.Add(`^a+`, `ab`)
f.Add(`^ab*`, `ab`)
f.Add(`^ab+`, `ab`)
f.Add(`^a\?b`, `ab`)
f.Add(`^[a-z]`, `ab`)
f.Fuzz(func(t *testing.T, lit, str string) {
re, err := regexp.Compile(lit)
if err != nil {
t.SkipNow()
}
if re.MatchString(str) {
prefix, ok := strings.CutSuffix(GlobPrefix(lit), "*")
if !ok {
t.Fatalf("missing * after %q for %q with %q", prefix, lit, str)
}
if !strings.HasPrefix(str, prefix) {
t.Fatalf("missing prefix %q for %q with %q", prefix, lit, str)
}
}
})
}

View File

@@ -1,72 +0,0 @@
// Package serdes provides functions to (de)serialize databases.
package serdes
import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/vfsutil"
"github.com/ncruces/go-sqlite3/vfs"
)
const vfsName = "github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS"
func init() {
vfs.Register(vfsName, sliceVFS{})
}
var fileToOpen = make(chan *[]byte, 1)
// Serialize backs up a database into a byte slice.
//
// https://sqlite.org/c3ref/serialize.html
func Serialize(db *sqlite3.Conn, schema string) ([]byte, error) {
var file []byte
fileToOpen <- &file
err := db.Backup(schema, "file:serdes.db?nolock=1&vfs="+vfsName)
return file, err
}
// Deserialize restores a database from a byte slice,
// DESTROYING any contents previously stored in schema.
//
// To non-destructively open a database from a byte slice,
// consider alternatives like the ["reader"] or ["memdb"] VFSes.
//
// This differs from the similarly named SQLite API
// in that it DOES NOT disconnect from schema
// to reopen as an in-memory database.
//
// https://sqlite.org/c3ref/deserialize.html
//
// ["memdb"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb
// ["reader"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs
func Deserialize(db *sqlite3.Conn, schema string, data []byte) error {
fileToOpen <- &data
return db.Restore(schema, "file:serdes.db?immutable=1&vfs="+vfsName)
}
type sliceVFS struct{}
func (sliceVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 || name != "serdes.db" {
return nil, flags, sqlite3.CANTOPEN
}
select {
case file := <-fileToOpen:
return (*vfsutil.SliceFile)(file), flags | vfs.OPEN_MEMORY, nil
default:
return nil, flags, sqlite3.MISUSE
}
}
func (sliceVFS) Delete(name string, dirSync bool) error {
// notest // no journals to delete
return sqlite3.IOERR_DELETE
}
func (sliceVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return name == "serdes.db", nil
}
func (sliceVFS) FullPathname(name string) (string, error) {
return name, nil
}

View File

@@ -1,115 +0,0 @@
package serdes_test
import (
_ "embed"
"errors"
"io"
"net/http"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/serdes"
)
//go:embed testdata/wal.db
var walDB []byte
func Test_wal(t *testing.T) {
db, err := sqlite3.Open("testdata/wal.db")
if err != nil {
t.Fatal(err)
}
defer db.Close()
data, err := serdes.Serialize(db, "main")
if err != nil {
t.Fatal(err)
}
compareDBs(t, data, walDB)
err = serdes.Deserialize(db, "temp", walDB)
if err != nil {
t.Fatal(err)
}
}
func Test_northwind(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
input, err := httpGet()
if err != nil {
t.Fatal(err)
}
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = serdes.Deserialize(db, "temp", input)
if err != nil {
t.Fatal(err)
}
output, err := serdes.Serialize(db, "temp")
if err != nil {
t.Fatal(err)
}
compareDBs(t, input, output)
}
func compareDBs(t *testing.T, a, b []byte) {
if len(a) != len(b) {
t.Fatal("lengths are different")
}
for i := range a {
// These may be different.
switch {
case 24 <= i && i < 28:
// File change counter.
continue
case 40 <= i && i < 44:
// Schema cookie.
continue
case 92 <= i && i < 100:
// SQLite version that wrote the file.
continue
}
if a[i] != b[i] {
t.Errorf("difference at %d: %d %d", i, a[i], b[i])
}
}
}
func httpGet() ([]byte, error) {
res, err := http.Get("https://github.com/jpwhite3/northwind-SQLite3/raw/refs/heads/main/dist/northwind.db")
if err != nil {
return nil, err
}
defer res.Body.Close()
return io.ReadAll(res.Body)
}
func TestOpen_errors(t *testing.T) {
_, err := sqlite3.Open("file:test.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
_, err = sqlite3.Open("file:serdes.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.MISUSE) {
t.Errorf("got %v, want sqlite3.MISUSE", err)
}
}

Binary file not shown.

View File

@@ -8,7 +8,6 @@ package statement
import (
"encoding/json"
"errors"
"strconv"
"strings"
"unsafe"
@@ -35,21 +34,16 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
sql := "SELECT * FROM\n" + arg[0]
stmt, tail, err := db.PrepareFlags(sql,
sqlite3.PREPARE_PERSISTENT|sqlite3.PREPARE_FROM_DDL)
stmt, _, err := db.PrepareFlags(sql, sqlite3.PREPARE_PERSISTENT)
if err != nil {
return nil, err
}
if tail != "" {
stmt.Close()
return nil, util.TailErr
}
var sep string
var str strings.Builder
str.WriteString("CREATE TABLE x(")
outputs := stmt.ColumnCount()
for i := range outputs {
for i := 0; i < outputs; i++ {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
@@ -134,8 +128,7 @@ func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
if !t.inuse {
t.inuse = true
} else {
stmt, _, err = t.stmt.Conn().PrepareFlags(t.sql,
sqlite3.PREPARE_DONT_LOG|sqlite3.PREPARE_FROM_DDL)
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
}
@@ -157,18 +150,17 @@ type cursor struct {
func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
return errors.Join(
c.stmt.Reset(),
c.stmt.ClearBindings())
c.stmt.ClearBindings()
return c.stmt.Reset()
}
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
err := errors.Join(
c.stmt.Reset(),
c.stmt.ClearBindings())
if err != nil {
c.arg = arg
c.rowID = 0
c.stmt.ClearBindings()
if err := c.stmt.Reset(); err != nil {
return err
}
@@ -191,8 +183,6 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
return err
}
}
c.arg = append(c.arg[:0], arg...)
c.rowID = 0
return c.Next()
}

View File

@@ -3,7 +3,6 @@ package statement_test
import (
"fmt"
"log"
"os"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -51,7 +50,7 @@ func Example() {
func TestMain(m *testing.M) {
sqlite3.AutoExtension(statement.Register)
os.Exit(m.Run())
m.Run()
}
func TestRegister(t *testing.T) {
@@ -92,9 +91,7 @@ func TestRegister(t *testing.T) {
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if stmt.Step() {
x := stmt.ColumnInt(0)
y := stmt.ColumnInt(1)
hypot := stmt.ColumnInt(2)

View File

@@ -1,6 +1,6 @@
# ANSI SQL Aggregate Functions
https://oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
## Built in aggregates

View File

@@ -7,7 +7,7 @@ const (
some
)
func newBoolean(kind int) sqlite3.AggregateConstructor {
func newBoolean(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
}

View File

@@ -37,9 +37,7 @@ func TestRegister_boolean(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if stmt.Step() {
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}

View File

@@ -1,19 +0,0 @@
package stats
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

View File

@@ -1,121 +0,0 @@
package stats
import (
"unsafe"
"github.com/ncruces/go-sqlite3"
)
func newMode() sqlite3.AggregateFunction {
return &mode{}
}
type mode struct {
ints counter[int64]
reals counter[float64]
texts counter[string]
blobs counter[string]
}
func (m mode) Value(ctx sqlite3.Context) {
var (
typ = sqlite3.NULL
max uint
i64 int64
f64 float64
str string
)
for k, v := range m.ints {
if v > max || v == max && k < i64 {
typ = sqlite3.INTEGER
max = v
i64 = k
}
}
for k, v := range m.reals {
if v > max || v == max && k < f64 {
typ = sqlite3.FLOAT
max = v
f64 = k
}
}
for k, v := range m.texts {
if v > max || v == max && typ == sqlite3.TEXT && k < str {
typ = sqlite3.TEXT
max = v
str = k
}
}
for k, v := range m.blobs {
if v > max || v == max && typ == sqlite3.BLOB && k < str {
typ = sqlite3.BLOB
max = v
str = k
}
}
switch typ {
case sqlite3.INTEGER:
ctx.ResultInt64(i64)
case sqlite3.FLOAT:
ctx.ResultFloat(f64)
case sqlite3.TEXT:
ctx.ResultText(str)
case sqlite3.BLOB:
ctx.ResultBlob(unsafe.Slice(unsafe.StringData(str), len(str)))
}
}
func (m *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg[0].Type() {
case sqlite3.INTEGER:
if m.reals == nil {
m.ints.add(arg[0].Int64())
break
}
fallthrough
case sqlite3.FLOAT:
m.reals.add(arg[0].Float())
for k, v := range m.ints {
m.reals[float64(k)] += v
}
m.ints = nil
case sqlite3.TEXT:
m.texts.add(arg[0].Text())
case sqlite3.BLOB:
m.blobs.add(string(arg[0].RawBlob()))
}
}
func (m *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg[0].Type() {
case sqlite3.INTEGER:
if m.reals == nil {
m.ints.del(arg[0].Int64())
break
}
fallthrough
case sqlite3.FLOAT:
m.reals.del(arg[0].Float())
case sqlite3.TEXT:
m.texts.del(arg[0].Text())
case sqlite3.BLOB:
m.blobs.del(string(arg[0].RawBlob()))
}
}
type counter[T comparable] map[T]uint
func (c *counter[T]) add(k T) {
if (*c) == nil {
(*c) = make(counter[T])
}
(*c)[k]++
}
func (c counter[T]) del(k T) {
if n := c[k]; n == 1 {
delete(c, k)
} else {
c[k] = n - 1
}
}

View File

@@ -1,102 +0,0 @@
package stats_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func TestRegister_mode(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT mode(column1) FROM (VALUES (NULL), (1), (NULL), (2), (NULL), (3), (3))`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %v, want 3", got)
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (1), (1), (2), (2), (3))`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnInt(0); got != 1 {
t.Errorf("got %v, want 1", got)
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (0.5), (1), (2.5), (2), (2.5))`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnFloat(0); got != 2.5 {
t.Errorf("got %v, want 2.5", got)
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES ('red'), ('green'), ('blue'), ('red'))`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != "red" {
t.Errorf("got %q, want red", got)
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (X'cafebabe'), ('green'), ('blue'), (X'cafebabe'))`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnText(0); got != "\xca\xfe\xba\xbe" {
t.Errorf("got %q, want cafebabe", got)
}
stmt.Close()
stmt, _, err = db.Prepare(`
SELECT mode(column1) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM (VALUES (1), (1), (2.5), ('blue'), (X'cafebabe'), (1), (1))
`)
if err != nil {
t.Fatal(err)
}
for stmt.Step() {
}
stmt.Close()
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (?), (?), (?), (?), (?))`)
if err != nil {
t.Fatal(err)
}
stmt.BindInt(1, 1)
stmt.BindInt(2, 1)
stmt.BindInt(3, 2)
stmt.BindFloat(4, 2)
stmt.BindFloat(5, 2)
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnInt(0); got != 2 {
t.Errorf("got %v, want 2", got)
}
stmt.Close()
}

View File

@@ -1,101 +0,0 @@
package stats
import "math"
// FisherPearson skewness and kurtosis using
// Terriberry's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
type moments struct {
m1, m2, m3, m4 kahan
n int64
}
func (m moments) mean() float64 {
return m.m1.hi
}
func (m moments) var_pop() float64 {
return m.m2.hi / float64(m.n)
}
func (m moments) var_samp() float64 {
return m.m2.hi / float64(m.n-1) // Bessel's correction
}
func (m moments) stddev_pop() float64 {
return math.Sqrt(m.var_pop())
}
func (m moments) stddev_samp() float64 {
return math.Sqrt(m.var_samp())
}
func (m moments) skewness_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2 * m2; div != 0 {
return m.m3.hi * math.Sqrt(float64(m.n)/div)
}
return math.NaN()
}
func (m moments) skewness_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/skewness.html#f1132178
return m.skewness_pop() * math.Sqrt(float64(n*(n-1))) / float64(n-2)
}
func (m moments) kurtosis_pop() float64 {
return m.raw_kurtosis_pop() - 3
}
func (m moments) raw_kurtosis_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2; div != 0 {
return m.m4.hi * float64(m.n) / div
}
return math.NaN()
}
func (m moments) kurtosis_samp() float64 {
n := m.n
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
return k * float64(n-1) / float64((n-2)*(n-3))
}
func (m moments) raw_kurtosis_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/kurtosis.html#f4975293
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
return math.FMA(k, float64(n-1)/float64((n-2)*(n-3)), 3)
}
func (m *moments) enqueue(x float64) {
n := m.n + 1
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n-1)
m.m4.add(t1*d2*float64(n*n-3*n+3) + 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.add(t1*dn*float64(n-2) - 3*dn*m.m2.hi)
m.m2.add(t1)
m.m1.add(dn)
}
func (m *moments) dequeue(x float64) {
n := m.n - 1
if n <= 0 {
*m = moments{}
return
}
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n+1)
m.m4.sub(t1*d2*float64(n*n+3*n+3) - 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.sub(t1*dn*float64(n+2) - 3*dn*m.m2.hi)
m.m2.sub(t1)
m.m1.sub(dn)
}

View File

@@ -1,87 +0,0 @@
package stats
import (
"math"
"testing"
)
func Test_moments(t *testing.T) {
t.Parallel()
var s1 moments
s1.enqueue(1)
s1.dequeue(1)
if !math.IsNaN(s1.skewness_pop()) {
t.Errorf("want NaN")
}
if !math.IsNaN(s1.raw_kurtosis_pop()) {
t.Errorf("want NaN")
}
s1.enqueue(+0.5377)
s1.enqueue(+1.8339)
s1.enqueue(-2.2588)
s1.enqueue(+0.8622)
s1.enqueue(+0.3188)
s1.enqueue(-1.3077)
s1.enqueue(-0.4336)
s1.enqueue(+0.3426)
s1.enqueue(+3.5784)
s1.enqueue(+2.7694)
if got := s1.skewness_pop(); float32(got) != 0.106098293 {
t.Errorf("got %v, want 0.1061", got)
}
if got := s1.skewness_samp(); float32(got) != 0.1258171 {
t.Errorf("got %v, want 0.1258", got)
}
if got := s1.raw_kurtosis_pop(); float32(got) != 2.3121266 {
t.Errorf("got %v, want 2.3121", got)
}
if got := s1.raw_kurtosis_samp(); float32(got) != 2.7482237 {
t.Errorf("got %v, want 2.7483", got)
}
var s2 welford
s2.enqueue(+0.5377)
s2.enqueue(+1.8339)
s2.enqueue(-2.2588)
s2.enqueue(+0.8622)
s2.enqueue(+0.3188)
s2.enqueue(-1.3077)
s2.enqueue(-0.4336)
s2.enqueue(+0.3426)
s2.enqueue(+3.5784)
s2.enqueue(+2.7694)
if got, want := s1.mean(), s2.mean(); got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := s1.stddev_pop(), s2.stddev_pop(); got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := s1.stddev_samp(), s2.stddev_samp(); got != want {
t.Errorf("got %v, want %v", got, want)
}
s1.enqueue(math.Pi)
s1.enqueue(math.Sqrt2)
s1.enqueue(math.E)
s1.dequeue(math.Pi)
s1.dequeue(math.E)
s1.dequeue(math.Sqrt2)
if got := s1.skewness_pop(); float32(got) != 0.106098293 {
t.Errorf("got %v, want 0.1061", got)
}
if got := s1.skewness_samp(); float32(got) != 0.1258171 {
t.Errorf("got %v, want 0.1258", got)
}
if got := s1.raw_kurtosis_pop(); float32(got) != 2.3121266 {
t.Errorf("got %v, want 2.3121", got)
}
if got := s1.raw_kurtosis_samp(); float32(got) != 2.7482237 {
t.Errorf("got %v, want 2.7483", got)
}
}

View File

@@ -11,9 +11,6 @@ import (
"github.com/ncruces/sort/quick"
)
// Compatible with:
// https://sqlite.org/src/file/ext/misc/percentile.c
const (
median = iota
percentile_100
@@ -21,7 +18,7 @@ const (
percentile_disc
)
func newPercentile(kind int) sqlite3.AggregateConstructor {
func newPercentile(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
}

View File

@@ -38,9 +38,7 @@ func TestRegister_percentile(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 10 {
t.Errorf("got %v, want 10", got)
}
@@ -67,30 +65,30 @@ func TestRegister_percentile(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnFloat(0); got != 5.5 {
t.Errorf("got %v, want 5.5", got)
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 5.5 {
t.Errorf("got %v, want 5.5", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnFloat(0); got != 7 {
t.Errorf("got %v, want 7", got)
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 7 {
t.Errorf("got %v, want 7", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnFloat(0); got != 10 {
t.Errorf("got %v, want 10", got)
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 10 {
t.Errorf("got %v, want 10", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnFloat(0); got != 14.5 {
t.Errorf("got %v, want 14.5", got)
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 14.5 {
t.Errorf("got %v, want 14.5", got)
}
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnFloat(0); got != 16 {
t.Errorf("got %v, want 16", got)
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 16 {
t.Errorf("got %v, want 16", got)
}
}
stmt.Close()
@@ -105,9 +103,7 @@ func TestRegister_percentile(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 4 {
t.Errorf("got %v, want 4", got)
}
@@ -138,9 +134,7 @@ func TestRegister_percentile(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Error("want NULL")
}

View File

@@ -1,17 +1,13 @@
// Package stats provides aggregate functions for statistics.
//
// Provided functions:
// - var_pop: population variance
// - var_samp: sample variance
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - skewness_pop: Pearson population skewness
// - skewness_samp: Pearson sample skewness
// - kurtosis_pop: Fisher population excess kurtosis
// - kurtosis_samp: Fisher sample excess kurtosis
// - var_pop: population variance
// - var_samp: sample variance
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: Pearson correlation coefficient
// - corr: correlation coefficient
// - regr_r2: correlation coefficient squared
// - regr_avgx: average of the independent variable
// - regr_avgy: average of the dependent variable
@@ -21,12 +17,10 @@
// - regr_count: count non-null pairs of variables
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
// - regr_json: all regr stats as a JSON object
// - percentile_disc: discrete quantile
// - percentile_cont: continuous quantile
// - percentile: continuous percentile
// - median: middle value
// - mode: most frequent value
// - regr_json: all regr stats in a JSON object
// - percentile_disc: discrete percentile
// - percentile_cont: continuous percentile
// - median: median value
// - every: boolean and
// - some: boolean or
//
@@ -47,7 +41,7 @@
//
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins
// [ANSI SQL Aggregate Functions]: https://oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import (
@@ -58,20 +52,13 @@ import (
// Register registers statistics functions.
func Register(db *sqlite3.Conn) error {
const (
flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
json = sqlite3.RESULT_SUBTYPE | flags
order = sqlite3.SELFORDER1 | flags
)
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
const order = sqlite3.SELFORDER1 | flags
return errors.Join(
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
db.CreateWindowFunction("skewness_pop", 1, flags, newMoments(skewness_pop)),
db.CreateWindowFunction("skewness_samp", 1, flags, newMoments(skewness_samp)),
db.CreateWindowFunction("kurtosis_pop", 1, flags, newMoments(kurtosis_pop)),
db.CreateWindowFunction("kurtosis_samp", 1, flags, newMoments(kurtosis_samp)),
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
@@ -84,14 +71,13 @@ func Register(db *sqlite3.Conn) error {
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)),
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
db.CreateWindowFunction("regr_json", 2, json, newCovariance(regr_json)),
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
db.CreateWindowFunction("median", 1, order, newPercentile(median)),
db.CreateWindowFunction("percentile", 2, order, newPercentile(percentile_100)),
db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)),
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
db.CreateWindowFunction("some", 1, flags, newBoolean(some)),
db.CreateWindowFunction("mode", 1, order, newMode))
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
}
const (
@@ -99,10 +85,6 @@ const (
var_samp
stddev_pop
stddev_samp
skewness_pop
skewness_samp
kurtosis_pop
kurtosis_samp
corr
regr_r2
regr_sxx
@@ -116,24 +98,7 @@ const (
regr_json
)
func special(kind int, n int64) (null, zero bool) {
switch kind {
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
return n <= 0, n == 1
case regr_avgx, regr_avgy:
return n <= 0, false
case kurtosis_samp:
return n <= 3, false
case skewness_samp:
return n <= 2, false
case skewness_pop:
return n <= 1, n == 2
default:
return n <= 1, false
}
}
func newVariance(kind int) sqlite3.AggregateConstructor {
func newVariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
}
@@ -143,14 +108,6 @@ type variance struct {
}
func (fn *variance) Value(ctx sqlite3.Context) {
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case var_pop:
@@ -181,7 +138,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
}
func newCovariance(kind int) sqlite3.AggregateConstructor {
func newCovariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
}
@@ -191,18 +148,6 @@ type covariance struct {
}
func (fn *covariance) Value(ctx sqlite3.Context) {
if fn.kind == regr_count {
ctx.ResultInt64(fn.regr_count())
return
}
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case var_pop:
@@ -227,10 +172,11 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
r = fn.regr_slope()
case regr_intercept:
r = fn.regr_intercept()
case regr_count:
ctx.ResultInt64(fn.regr_count())
return
case regr_json:
var buf [128]byte
ctx.ResultRawText(fn.regr_json(buf[:0]))
ctx.ResultSubtype('J')
ctx.ResultText(fn.regr_json())
return
}
ctx.ResultFloat(r)
@@ -257,51 +203,3 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
fn.dequeue(fa, fb)
}
}
func newMoments(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
}
type momentfn struct {
kind int
moments
}
func (fn *momentfn) Value(ctx sqlite3.Context) {
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case skewness_pop:
r = fn.skewness_pop()
case skewness_samp:
r = fn.skewness_samp()
case kurtosis_pop:
r = fn.kurtosis_pop()
case kurtosis_samp:
r = fn.kurtosis_samp()
}
ctx.ResultFloat(r)
}
func (fn *momentfn) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.enqueue(f)
}
}
func (fn *momentfn) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.dequeue(f)
}
}

View File

@@ -2,7 +2,6 @@ package stats_test
import (
"math"
"os"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -13,7 +12,7 @@ import (
func TestMain(m *testing.M) {
sqlite3.AutoExtension(stats.Register)
os.Exit(m.Run())
m.Run()
}
func TestRegister_variance(t *testing.T) {
@@ -30,36 +29,21 @@ func TestRegister_variance(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
stmt.Close()
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`
stmt, _, err := db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
stddev_samp(x), stddev_pop(x),
skewness_samp(x), skewness_pop(x),
kurtosis_samp(x), kurtosis_pop(x)
stddev_samp(x), stddev_pop(x)
FROM data`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
}
@@ -78,27 +62,10 @@ func TestRegister_variance(t *testing.T) {
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
if got := stmt.ColumnFloat(6); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(7); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(8); float32(got) != -3.3 {
t.Errorf("got %v, want -3.3", got)
}
if got := stmt.ColumnFloat(9); got != -1.64 {
t.Errorf("got %v, want -1.64", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`
SELECT
var_samp(x) OVER (ROWS 1 PRECEDING),
var_pop(x) OVER (ROWS 1 PRECEDING),
skewness_pop(x) OVER (ROWS 1 PRECEDING)
FROM data`)
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
@@ -129,28 +96,12 @@ func TestRegister_covariance(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 0", got)
}
if got := stmt.ColumnType(1); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
stmt.Close()
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT
stmt, _, err := db.Prepare(`SELECT
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
@@ -160,59 +111,53 @@ func TestRegister_covariance(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !stmt.Step() {
t.Fatal(stmt.Err())
}
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
if got := stmt.ColumnFloat(3); got != 4.2 {
t.Errorf("got %v, want 4.2", got)
}
if got := stmt.ColumnFloat(4); got != 75 {
t.Errorf("got %v, want 75", got)
}
if got := stmt.ColumnFloat(5); got != 14.8 {
t.Errorf("got %v, want 14.8", got)
}
if got := stmt.ColumnFloat(6); got != 500 {
t.Errorf("got %v, want 500", got)
}
if got := stmt.ColumnFloat(7); got != 85 {
t.Errorf("got %v, want 85", got)
}
if got := stmt.ColumnFloat(8); got != 0.17 {
t.Errorf("got %v, want 0.17", got)
}
if got := stmt.ColumnFloat(9); got != -8.55 {
t.Errorf("got %v, want -8.55", got)
}
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
}
if got := stmt.ColumnInt(11); got != 5 {
t.Errorf("got %v, want 5", got)
}
var a map[string]float64
if err := stmt.ColumnJSON(12, &a); err != nil {
t.Error(err)
} else if got := a["count"]; got != 5 {
t.Errorf("got %v, want 5", got)
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
if got := stmt.ColumnFloat(3); got != 4.2 {
t.Errorf("got %v, want 4.2", got)
}
if got := stmt.ColumnFloat(4); got != 75 {
t.Errorf("got %v, want 75", got)
}
if got := stmt.ColumnFloat(5); got != 14.8 {
t.Errorf("got %v, want 14.8", got)
}
if got := stmt.ColumnFloat(6); got != 500 {
t.Errorf("got %v, want 500", got)
}
if got := stmt.ColumnFloat(7); got != 85 {
t.Errorf("got %v, want 85", got)
}
if got := stmt.ColumnFloat(8); got != 0.17 {
t.Errorf("got %v, want 0.17", got)
}
if got := stmt.ColumnFloat(9); got != -8.55 {
t.Errorf("got %v, want -8.55", got)
}
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
}
if got := stmt.ColumnInt(11); got != 5 {
t.Errorf("got %v, want 5", got)
}
var a map[string]float64
if err := stmt.ColumnJSON(12, &a); err != nil {
t.Error(err)
} else if got := a["count"]; got != 5 {
t.Errorf("got %v, want 5", got)
}
}
stmt.Close()
stmt, _, err = db.Prepare(`
SELECT
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
FROM data`)
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
@@ -226,9 +171,6 @@ func TestRegister_covariance(t *testing.T) {
t.Errorf("got %v, want %v", got, want[i])
}
}
if stmt.Err() != nil {
t.Fatal(stmt.Err())
}
stmt.Close()
}
@@ -253,9 +195,7 @@ func Benchmark_average(b *testing.B) {
b.Fatal(err)
}
if !stmt.Step() {
b.Fatal(stmt.Err())
} else {
if stmt.Step() {
want := float64(b.N) / 2
if got := stmt.ColumnFloat(0); got != want {
b.Errorf("got %v, want %v", got, want)
@@ -289,9 +229,7 @@ func Benchmark_variance(b *testing.B) {
b.Fatal(err)
}
if !stmt.Step() {
b.Fatal(stmt.Err())
} else if b.N > 100 {
if stmt.Step() && b.N > 100 {
want := float64(b.N*b.N) / 12
if got := stmt.ColumnFloat(0); want > (got-want)*float64(b.N) {
b.Errorf("got %v, want %v", got, want)

View File

@@ -3,20 +3,22 @@ package stats
import (
"math"
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
"strings"
)
// Welford's algorithm with Kahan summation:
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
// See also:
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
type welford struct {
m1, m2 kahan
n int64
}
func (w welford) mean() float64 {
func (w welford) average() float64 {
return w.m1.hi
}
@@ -37,23 +39,17 @@ func (w welford) stddev_samp() float64 {
}
func (w *welford) enqueue(x float64) {
n := w.n + 1
w.n = n
w.n++
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(n))
w.m1.add(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}
func (w *welford) dequeue(x float64) {
n := w.n - 1
if n <= 0 {
*w = welford{}
return
}
w.n = n
w.n--
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(n))
w.m1.sub(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
@@ -116,35 +112,38 @@ func (w welford2) regr_r2() float64 {
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
}
func (w welford2) regr_json(dst []byte) []byte {
dst = append(dst, `{"count":`...)
dst = strconv.AppendInt(dst, w.regr_count(), 10)
dst = append(dst, `,"avgy":`...)
dst = util.AppendNumber(dst, w.regr_avgy())
dst = append(dst, `,"avgx":`...)
dst = util.AppendNumber(dst, w.regr_avgx())
dst = append(dst, `,"syy":`...)
dst = util.AppendNumber(dst, w.regr_syy())
dst = append(dst, `,"sxx":`...)
dst = util.AppendNumber(dst, w.regr_sxx())
dst = append(dst, `,"sxy":`...)
dst = util.AppendNumber(dst, w.regr_sxy())
dst = append(dst, `,"slope":`...)
dst = util.AppendNumber(dst, w.regr_slope())
dst = append(dst, `,"intercept":`...)
dst = util.AppendNumber(dst, w.regr_intercept())
dst = append(dst, `,"r2":`...)
dst = util.AppendNumber(dst, w.regr_r2())
return append(dst, '}')
func (w welford2) regr_json() string {
var json strings.Builder
var num [32]byte
json.Grow(128)
json.WriteString(`{"count":`)
json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10))
json.WriteString(`,"avgy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64))
json.WriteString(`,"avgx":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64))
json.WriteString(`,"syy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64))
json.WriteString(`,"sxx":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64))
json.WriteString(`,"sxy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64))
json.WriteString(`,"slope":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64))
json.WriteString(`,"intercept":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64))
json.WriteString(`,"r2":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64))
json.WriteByte('}')
return json.String()
}
func (w *welford2) enqueue(y, x float64) {
n := w.n + 1
w.n = n
w.n++
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.add(d1y / float64(n))
w.m1x.add(d1x / float64(n))
w.m1y.add(d1y / float64(w.n))
w.m1x.add(d1x / float64(w.n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.add(d1y * d2y)
@@ -153,19 +152,30 @@ func (w *welford2) enqueue(y, x float64) {
}
func (w *welford2) dequeue(y, x float64) {
n := w.n - 1
if n <= 0 {
*w = welford2{}
return
}
w.n = n
w.n--
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.sub(d1y / float64(n))
w.m1x.sub(d1x / float64(n))
w.m1y.sub(d1y / float64(w.n))
w.m1x.sub(d1x / float64(w.n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.sub(d1y * d2y)
w.m2x.sub(d1x * d2x)
w.cov.sub(d1y * d2x)
}
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

View File

@@ -9,14 +9,12 @@ func Test_welford(t *testing.T) {
t.Parallel()
var s1, s2 welford
s1.enqueue(1)
s1.dequeue(1)
s1.enqueue(4)
s1.enqueue(7)
s1.enqueue(13)
s1.enqueue(16)
if got := s1.mean(); got != 10 {
if got := s1.average(); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := s1.var_samp(); got != 30 {
@@ -45,8 +43,6 @@ func Test_covar(t *testing.T) {
t.Parallel()
var c1, c2 welford2
c1.enqueue(1, 1)
c1.dequeue(1, 1)
c1.enqueue(3, 70)
c1.enqueue(5, 80)

View File

@@ -1,22 +1,19 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions
// - LIKE and REGEXP operators
// - collation sequences
// - upper() and lower() functions,
// - LIKE and REGEXP operators,
// - collation sequences.
//
// Like PostgreSQL, it also provides:
// - initcap()
// - casefold()
// - normalize()
// - unaccent()
// It also provides, from PostgreSQL:
// - unaccent(),
// - initcap().
//
// The implementations are not 100% compatible:
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]
// - normalize(), unaccent() use [transform] and [unicode.Mn]
// - the LIKE operator follows [strings.EqualFold] rules
// - the REGEXP operator uses Go [regexp/syntax]
// - collation sequences use [collate]
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regexp/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
@@ -28,7 +25,6 @@ import (
"errors"
"regexp"
"strings"
"sync"
"unicode"
"unicode/utf8"
@@ -43,7 +39,7 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
)
// RegisterLike must be set to false to not register a Unicode aware LIKE operator.
// Set RegisterLike to false to not register a Unicode aware LIKE operator.
// Overriding the built-in LIKE operator disables the [LIKE optimization].
//
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
@@ -52,13 +48,13 @@ var RegisterLike = true
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
var lkfn sqlite3.ScalarFunction
var errs util.ErrorJoiner
if RegisterLike {
lkfn = like
errs.Join(
db.CreateFunction("like", 2, flags, like),
db.CreateFunction("like", 3, flags, like))
}
return errors.Join(
db.CreateFunction("like", 2, flags, lkfn),
db.CreateFunction("like", 3, flags, lkfn),
errs.Join(
db.CreateFunction("upper", 1, flags, upper),
db.CreateFunction("upper", 2, flags, upper),
db.CreateFunction("lower", 1, flags, lower),
@@ -66,10 +62,7 @@ func Register(db *sqlite3.Conn) error {
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("initcap", 1, flags, initcap),
db.CreateFunction("initcap", 2, flags, initcap),
db.CreateFunction("casefold", 1, flags, casefold),
db.CreateFunction("unaccent", 1, flags, unaccent),
db.CreateFunction("normalize", 1, flags, normalize),
db.CreateFunction("normalize", 2, flags, normalize),
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
@@ -83,6 +76,7 @@ func Register(db *sqlite3.Conn) error {
return // notest
}
}))
return errors.Join(errs...)
}
// RegisterCollation registers a Unicode collation sequence for a database connection.
@@ -115,8 +109,9 @@ func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
cs = cases.Upper(t)
ctx.SetAuxData(1, cs)
c := cases.Upper(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -133,8 +128,9 @@ func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
cs = cases.Lower(t)
ctx.SetAuxData(1, cs)
c := cases.Lower(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -151,26 +147,15 @@ func initcap(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
cs = cases.Title(t)
ctx.SetAuxData(1, cs)
c := cases.Title(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
func casefold(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultRawText(cases.Fold().Bytes(arg[0].RawText()))
}
var unaccentPool = sync.Pool{
New: func() any {
return transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
},
}
func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
unaccent := unaccentPool.Get().(transform.Transformer)
defer unaccentPool.Put(unaccent)
unaccent := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
res, _, err := transform.Bytes(unaccent, arg[0].RawText())
if err != nil {
ctx.ResultError(err) // notest
@@ -179,44 +164,16 @@ func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
}
func normalize(ctx sqlite3.Context, arg ...sqlite3.Value) {
form := norm.NFC
if len(arg) > 1 {
switch strings.ToUpper(arg[1].Text()) {
case "NFC":
//
case "NFD":
form = norm.NFD
case "NFKC":
form = norm.NFKC
case "NFKD":
form = norm.NFKD
default:
ctx.ResultError(util.ErrorString("unicode: invalid form"))
return
}
}
res, _, err := transform.Bytes(form, arg[0].RawText())
if err != nil {
ctx.ResultError(err) // notest
} else {
ctx.ResultRawText(res)
}
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
re, ok = arg[0].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
re = r
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.SetAuxData(0, re)
re = r
ctx.SetAuxData(0, r)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
}

View File

@@ -2,7 +2,7 @@ package unicode
import (
"errors"
"slices"
"reflect"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -26,10 +26,11 @@ func TestRegister(t *testing.T) {
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal(stmt.Err())
if stmt.Step() {
return stmt.ColumnText(0)
}
return stmt.ColumnText(0)
t.Fatal(stmt.Err())
return ""
}
Register(db)
@@ -48,12 +49,6 @@ func TestRegister(t *testing.T) {
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
{`initcap('Kad je hladno Marko nosi džemper')`, "Kad Je Hladno Marko Nosi Džemper"},
{`initcap('Kad je hladno Marko nosi džemper', 'hr-HR')`, "Kad Je Hladno Marko Nosi Džemper"},
{`normalize(X'61cc88')`, "ä"},
{`normalize(X'61cc88', 'NFC' )`, "ä"},
{`normalize(X'61cc88', 'NFKC')`, "ä"},
{`normalize('ä', 'NFD' )`, "\x61\xcc\x88"},
{`normalize('ä', 'NFKD')`, "\x61\xcc\x88"},
{`casefold('Maße')`, "masse"},
{`unaccent('Hôtel')`, "Hotel"},
{`'Hello' REGEXP 'ell'`, "1"},
{`'Hello' REGEXP 'el.'`, "1"},
@@ -120,7 +115,7 @@ func TestRegister_collation(t *testing.T) {
t.Fatal(err)
}
if !slices.Equal(got, want) {
if !reflect.DeepEqual(got, want) {
t.Error("not equal")
}
@@ -171,7 +166,7 @@ func TestRegisterCollationsNeeded(t *testing.T) {
t.Fatal(err)
}
if !slices.Equal(got, want) {
if !reflect.DeepEqual(got, want) {
t.Error("not equal")
}
@@ -213,14 +208,6 @@ func TestRegister_error(t *testing.T) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT normalize('', 'NF')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
if err == nil {
t.Error("want error")

View File

@@ -7,7 +7,6 @@ import (
"bytes"
"errors"
"fmt"
"time"
"github.com/google/uuid"
@@ -17,18 +16,17 @@ import (
// Register registers the SQL functions:
//
// - uuid([ version [, domain/namespace, [ id/data ]]]):
// to generate a UUID as a string
// - uuid_str(u):
// to convert a UUID into a well-formed UUID string
// - uuid_blob(u):
// to convert a UUID into a 16-byte blob
// - uuid_extract_version(u):
// to extract the version of a RFC 4122 UUID
// - uuid_extract_timestamp(u):
// to extract the timestamp of a version 1/2/6/7 UUID
// - gen_random_uuid(u):
// to generate a version 4 (random) UUID
// uuid([version], [domain/namespace], [id/data])
//
// Generates a UUID as a string.
//
// uuid_str(u)
//
// Converts a UUID into a well-formed UUID string.
//
// uuid_blob(u)
//
// Converts a UUID into a 16-byte blob.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
@@ -37,10 +35,7 @@ func Register(db *sqlite3.Conn) error {
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid_str", 1, flags, toString),
db.CreateFunction("uuid_blob", 1, flags, toBlob),
db.CreateFunction("uuid_extract_version", 1, flags, version),
db.CreateFunction("uuid_extract_timestamp", 1, flags, timestamp),
db.CreateFunction("gen_random_uuid", 0, sqlite3.INNOCUOUS, generate))
db.CreateFunction("uuid_blob", 1, flags, toBlob))
}
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
@@ -172,30 +167,3 @@ func toString(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(u.String())
}
}
func version(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
return // notest
}
if u.Variant() == uuid.RFC4122 {
ctx.ResultInt64(int64(u.Version()))
}
}
func timestamp(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
return // notest
}
if u.Variant() == uuid.RFC4122 {
switch u.Version() {
case 1, 2, 6, 7:
ctx.ResultTime(
time.Unix(u.Time().UnixTime()).UTC(),
sqlite3.TimeFormatDefault)
}
}
}

View File

@@ -2,7 +2,6 @@ package uuid
import (
"testing"
"time"
"github.com/google/uuid"
@@ -14,9 +13,9 @@ import (
func Test_generate(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, Register)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}
@@ -107,26 +106,7 @@ func Test_generate(t *testing.T) {
t.Error("want error")
}
var tstamp time.Time
var version uuid.Version
err = db.QueryRow(`
SELECT
column1,
uuid_extract_version(column1),
uuid_extract_timestamp(column1)
FROM (VALUES (uuid(7)))
`).Scan(&u, &version, &tstamp)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != version {
t.Errorf("got %d, want %d", got, version)
}
if got := time.Unix(u.Time().UnixTime()); !got.Equal(tstamp) {
t.Errorf("got %v, want %v", got, tstamp)
}
tests := []struct {
hash := []struct {
ver uuid.Version
ns any
data string
@@ -140,7 +120,7 @@ func Test_generate(t *testing.T) {
{3, "url", "https://www.php.net", uuid.MustParse("3f703955-aaba-3e70-a3cb-baff6aa3b28f")},
{5, "url", "https://www.php.net", uuid.MustParse("a8f6ae40-d8a7-58f0-be05-a22f94eca9ec")},
}
for _, tt := range tests {
for _, tt := range hash {
err = db.QueryRow(`SELECT uuid(?, ?, ?)`, tt.ver, tt.ns, tt.data).Scan(&u)
if err != nil {
t.Fatal(err)
@@ -153,23 +133,23 @@ func Test_generate(t *testing.T) {
func Test_convert(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, Register)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var u uuid.UUID
tests := []string{
lits := []string{
"'6ba7b8119dad11d180b400c04fd430c8'",
"'6ba7b811-9dad-11d1-80b4-00c04fd430c8'",
"'{6ba7b811-9dad-11d1-80b4-00c04fd430c8}'",
"X'6ba7b8119dad11d180b400c04fd430c8'",
}
for _, tt := range tests {
for _, tt := range lits {
err = db.QueryRow(`SELECT uuid_str(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
@@ -179,7 +159,7 @@ func Test_convert(t *testing.T) {
}
}
for _, tt := range tests {
for _, tt := range lits {
err = db.QueryRow(`SELECT uuid_blob(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
@@ -198,14 +178,4 @@ func Test_convert(t *testing.T) {
if err == nil {
t.Fatal("want error")
}
err = db.QueryRow(`SELECT uuid_extract_version(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
err = db.QueryRow(`SELECT uuid_extract_timestamp(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
}

View File

@@ -19,9 +19,9 @@ func Register(db *sqlite3.Conn) error {
}
func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
var x [24]int64
if n := len(arg); n < 2 || n > 24 {
ctx.ResultError(util.ErrorString("zorder: needs between 2 and 24 dimensions"))
var x [63]int64
if len(arg) > len(x) {
ctx.ResultError(util.ErrorString("zorder: too many parameters"))
return
}
for i := range arg {
@@ -29,15 +29,17 @@ func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
var z int64
for i := range 63 {
j := i % len(arg)
z |= (x[j] & 1) << i
x[j] >>= 1
if len(arg) > 0 {
for i := range x {
j := i % len(arg)
z |= (x[j] & 1) << i
x[j] >>= 1
}
}
for i := range arg {
if x[i] != 0 {
ctx.ResultError(util.ErrorString("zorder: argument out of range"))
ctx.ResultError(util.ErrorString("zorder: parameter too large"))
return
}
}
@@ -49,19 +51,6 @@ func unzorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
n := arg[1].Int64()
z := arg[0].Int64()
if n < 2 || n > 24 {
ctx.ResultError(util.ErrorString("unzorder: needs between 2 and 24 dimensions"))
return
}
if i < 0 || i >= n {
ctx.ResultError(util.ErrorString("unzorder: index out of range"))
return
}
if z < 0 {
ctx.ResultError(util.ErrorString("unzorder: argument out of range"))
return
}
var k int
var x int64
for j := i; j < 63; j += n {

View File

@@ -12,11 +12,11 @@ import (
"github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Test_zorder(t *testing.T) {
func TestRegister_zorder(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, zorder.Register)
db, err := driver.Open(tmp, zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -57,11 +57,11 @@ func Test_zorder(t *testing.T) {
}
}
func Test_unzorder(t *testing.T) {
func TestRegister_unzorder(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, zorder.Register)
db, err := driver.Open(tmp, zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -85,11 +85,11 @@ func Test_unzorder(t *testing.T) {
}
}
func Test_zorder_error(t *testing.T) {
func TestRegister_error(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
tmp := memdb.TestDB(t)
db, err := driver.Open(dsn, zorder.Register)
db, err := driver.Open(tmp, zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -103,7 +103,7 @@ func Test_zorder_error(t *testing.T) {
var buf strings.Builder
buf.WriteString("SELECT zorder(0")
for i := 1; i < 25; i++ {
for i := 1; i < 80; i++ {
buf.WriteByte(',')
buf.WriteString(strconv.Itoa(0))
}
@@ -113,30 +113,3 @@ func Test_zorder_error(t *testing.T) {
t.Error("want error")
}
}
func Test_unzorder_error(t *testing.T) {
t.Parallel()
dsn := memdb.TestDB(t)
db, err := driver.Open(dsn, zorder.Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var got int64
err = db.QueryRow(`SELECT unzorder(-1, 2, 0)`).Scan(&got)
if err == nil {
t.Error("want error")
}
err = db.QueryRow(`SELECT unzorder(0, 2, 2)`).Scan(&got)
if err == nil {
t.Error("want error")
}
err = db.QueryRow(`SELECT unzorder(0, 25, 2)`).Scan(&got)
if err == nil {
t.Error("want error")
}
}

231
func.go
View File

@@ -2,10 +2,7 @@ package sqlite3
import (
"context"
"io"
"iter"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
@@ -17,12 +14,12 @@ import (
//
// https://sqlite.org/c3ref/collation_needed.html
func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
var enable int32
var enable uint64
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_collation_needed_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
return err
}
c.collation = cb
@@ -36,8 +33,8 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c Conn) AnyCollationNeeded() error {
rc := res_t(c.call("sqlite3_anycollseq_init", stk_t(c.handle), 0, 0))
if err := c.error(rc); err != nil {
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
if err := c.error(r); err != nil {
return err
}
c.collation = nil
@@ -47,103 +44,60 @@ func (c Conn) AnyCollationNeeded() error {
// CreateCollation defines a new collating sequence.
//
// https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
var funcPtr ptr_t
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
var funcPtr uint32
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
rc := res_t(c.call("sqlite3_create_collation_go",
stk_t(c.handle), stk_t(namePtr), stk_t(funcPtr)))
return c.error(rc)
r := c.call("sqlite3_create_collation_go",
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
return c.error(r)
}
// CollatingFunction is the type of a collation callback.
// Implementations must not retain a or b.
type CollatingFunction func(a, b []byte) int
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
var funcPtr ptr_t
var funcPtr uint32
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
rc := res_t(c.call("sqlite3_create_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
r := c.call("sqlite3_create_function_go",
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// ScalarFunction is the type of a scalar SQL function.
// Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value)
// CreateAggregateFunction defines a new aggregate SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
var a aggregateFunc
coro := func(yieldCoro func(struct{}) bool) {
seq := func(yieldSeq func([]Value) bool) {
for yieldSeq(a.arg) {
if !yieldCoro(struct{}{}) {
break
}
}
}
fn(&a.ctx, seq)
}
a.next, a.stop = iter.Pull(coro)
return &a
}))
}
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateSeqFunction is the type of an aggregate SQL function.
// Implementations must not retain the slices yielded by seq.
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], an aggregate window function is created.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
var funcPtr ptr_t
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
var funcPtr uint32
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
agg := fn()
if win, ok := agg.(WindowFunction); ok {
return win
}
return agg
}))
funcPtr = util.AddHandle(c.ctx, fn)
}
rc := res_t(c.call("sqlite3_create_window_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
call := "sqlite3_create_aggregate_function_go"
if _, ok := fn().(WindowFunction); ok {
call = "sqlite3_create_window_function_go"
}
r := c.call(call,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// AggregateConstructor is a an [AggregateFunction] constructor.
type AggregateConstructor func() AggregateFunction
// AggregateFunction is the interface an aggregate function should implement.
//
// https://sqlite.org/appfunc.html
@@ -175,135 +129,102 @@ type WindowFunction interface {
func (c *Conn) OverloadFunction(name string, nArg int) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
rc := res_t(c.call("sqlite3_overload_function",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg)))
return c.error(rc)
r := c.call("sqlite3_overload_function",
uint64(c.handle), uint64(namePtr), uint64(nArg))
return c.error(r)
}
func destroyCallback(ctx context.Context, mod api.Module, pApp ptr_t) {
func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTextRep uint32, zName ptr_t) {
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil {
name := util.ReadString(mod, zName, _MAX_NAME)
c.collation(c, name)
}
}
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) {
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
fn(Context{db, pCtx}, *args...)
callbackArgs(db, args[:nArg], pArg)
fn(Context{db, pCtx}, args[:nArg]...)
}
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) {
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
callbackArgs(db, args[:nArg], pArg)
fn, _ := callbackAggregate(db, pAgg, pApp)
fn.Step(Context{db, pCtx}, *args...)
fn.Step(Context{db, pCtx}, args[:nArg]...)
}
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn, handle := callbackAggregate(db, pAgg, pApp)
fn.Value(Context{db, pCtx})
// Cleanup.
if final != 0 {
var err error
if handle != 0 {
err = util.DelHandle(ctx, handle)
} else if c, ok := fn.(io.Closer); ok {
err = c.Close()
}
if err != nil {
Context{db, pCtx}.ResultError(err)
return // notest
}
}
util.DelHandle(ctx, handle)
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
fn.Inverse(Context{db, pCtx}, *args...)
fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction)
fn.Value(Context{db, pCtx})
}
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
callbackArgs(db, args[:nArg], pArg)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
fn.Inverse(Context{db, pCtx}, args[:nArg]...)
}
func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) {
if pApp == 0 {
handle := util.Read32[ptr_t](db.mod, pAgg)
handle := util.ReadUint32(db.mod, pAgg)
return util.GetHandle(db.ctx, handle).(AggregateFunction), handle
}
// We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.Write32(db.mod, pAgg, handle)
util.WriteUint32(db.mod, pAgg, handle)
return fn, handle
}
return fn, 0
}
var (
valueArgsPool sync.Pool
valueArgsLen atomic.Int32
)
func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value {
arg, ok := valueArgsPool.Get().(*[]Value)
if !ok || cap(*arg) < int(nArg) {
max := valueArgsLen.Or(nArg) | nArg
lst := make([]Value, max)
arg = &lst
}
lst := (*arg)[:nArg]
for i := range lst {
lst[i] = Value{
func callbackArgs(db *Conn, arg []Value, pArg uint32) {
for i := range arg {
arg[i] = Value{
c: db,
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen),
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
}
}
*arg = lst
return arg
}
func returnArgs(p *[]Value) {
valueArgsPool.Put(p)
var funcArgsPool sync.Pool
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
funcArgsPool.Put(p)
}
type aggregateFunc struct {
next func() (struct{}, bool)
stop func()
ctx Context
arg []Value
}
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
a.ctx = ctx
a.arg = append(a.arg[:0], arg...)
if _, more := a.next(); !more {
a.stop()
func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
if p := funcArgsPool.Get(); p == nil {
return new([_MAX_FUNCTION_ARG]Value)
} else {
return p.(*[_MAX_FUNCTION_ARG]Value)
}
}
func (a *aggregateFunc) Value(ctx Context) {
a.ctx = ctx
a.stop()
}
func (a *aggregateFunc) Close() error {
a.stop()
return nil
}

View File

@@ -1,57 +0,0 @@
package sqlite3_test
import (
"fmt"
"iter"
"log"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateAggregateFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE test (col)`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`)
if err != nil {
log.Fatal(err)
}
err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS,
func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) {
count := 0
total := 0.0
for arg := range seq {
total += arg[0].Float()
count++
}
ctx.ResultFloat(total / float64(count))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnFloat(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 2
}

22
go.mod
View File

@@ -1,26 +1,24 @@
module github.com/ncruces/go-sqlite3
go 1.24.0
go 1.21
toolchain go1.23.0
require (
github.com/ncruces/julianday v1.0.0
github.com/ncruces/sort v0.1.6
github.com/ncruces/wbt v1.0.0
github.com/tetratelabs/wazero v1.11.0
golang.org/x/sys v0.40.0
github.com/ncruces/sort v0.1.2
github.com/tetratelabs/wazero v1.8.2
golang.org/x/crypto v0.32.0
golang.org/x/sys v0.29.0
)
require (
github.com/dchest/siphash v1.2.3 // ext/bloom
github.com/google/uuid v1.6.0 // ext/uuid
github.com/psanford/httpreadat v0.1.0 // example
golang.org/x/crypto v0.47.0 // vfs/adiantum vfs/xts
golang.org/x/sync v0.19.0 // test
golang.org/x/text v0.33.0 // ext/unicode
golang.org/x/sync v0.10.0 // test
golang.org/x/text v0.21.0 // ext/unicode
lukechampine.com/adiantum v1.1.1 // vfs/adiantum
)
retract (
v0.23.2 // tagged from the wrong branch
v0.4.0 // tagged from the wrong branch
)
retract v0.4.0 // tagged from the wrong branch

26
go.sum
View File

@@ -4,21 +4,19 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/sort v0.1.6 h1:TrsJfGRH1AoWoaeB4/+gCohot9+cA6u/INaH5agIhNk=
github.com/ncruces/sort v0.1.6/go.mod h1:obJToO4rYr6VWP0Uw5FYymgYGt3Br4RXcs/JdKaXAPk=
github.com/ncruces/wbt v1.0.0 h1:8iBE7UPjTLUpzu3/FCRjAmuQjWzgxo10RGBgt3ooLSc=
github.com/ncruces/wbt v1.0.0/go.mod h1:DtF92amvMxH69EmBFUSFWRDAlo6hOEfoNQnClxj9C/c=
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=
github.com/ncruces/sort v0.1.2/go.mod h1:vEJUTBJtebIuCMmXD18GKo5GJGhsay+xZFOoBEIXFmE=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
lukechampine.com/adiantum v1.1.1 h1:4fp6gTxWCqpEbLy40ExiYDDED3oUNWx5cTqBCtPdZqA=
lukechampine.com/adiantum v1.1.1/go.mod h1:LrAYVnTYLnUtE/yMp5bQr0HstAf060YUF8nM0B6+rUw=

7
go.work Normal file
View File

@@ -0,0 +1,7 @@
go 1.21
use (
.
./gormlite
./embed/bcw2
)

17
go.work.sum Normal file
View File

@@ -0,0 +1,17 @@
github.com/ncruces/go-sqlite3 v0.21.0/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=

View File

@@ -209,12 +209,8 @@ func (d *ddl) renameTable(dst, src string) error {
return nil
}
func compileConstraintRegexp(name string) *regexp.Regexp {
return regexp.MustCompile("^(?i:CONSTRAINT)\\s+[\"`]?" + regexp.QuoteMeta(name) + "[\"`\\s]")
}
func (d *ddl) addConstraint(name string, sql string) {
reg := compileConstraintRegexp(name)
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
@@ -227,7 +223,7 @@ func (d *ddl) addConstraint(name string, sql string) {
}
func (d *ddl) removeConstraint(name string) bool {
reg := compileConstraintRegexp(name)
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
@@ -240,7 +236,7 @@ func (d *ddl) removeConstraint(name string) bool {
//lint:ignore U1000 ignore unused code.
func (d *ddl) hasConstraint(name string) bool {
reg := compileConstraintRegexp(name)
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for _, f := range d.fields {
if reg.MatchString(f) {

View File

@@ -95,7 +95,7 @@ func parseAllColumns(in string) ([]string, error) {
}
return nil, fmt.Errorf("unexpected token: %s", string(s[i]))
case parseAllColumnsState_State_End:
continue // avoid SA4011
break
}
}
if state != parseAllColumnsState_State_End {

Some files were not shown because too many files have changed in this diff Show More