mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 22:19:14 +00:00
Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
591480cd39 | ||
|
|
828788912e | ||
|
|
6f8645cd2e | ||
|
|
c00927e8bb | ||
|
|
6b28be6d0e | ||
|
|
310b4ff29d | ||
|
|
e82cf16b11 | ||
|
|
24c9b57c56 | ||
|
|
24b965ac7e | ||
|
|
446168c572 | ||
|
|
a9e2cbbfc5 | ||
|
|
a7c00eb150 | ||
|
|
0bcdb712ba | ||
|
|
2157d0f325 | ||
|
|
6353160619 | ||
|
|
501d157279 | ||
|
|
4db18a7b9a | ||
|
|
a9dddaa86c | ||
|
|
b25936dbec | ||
|
|
bf23041e46 | ||
|
|
d60fceac92 | ||
|
|
61da30f44a | ||
|
|
d4ff605983 | ||
|
|
8d0c654178 | ||
|
|
728e59951b | ||
|
|
f7b16bad5c | ||
|
|
db3e6da31a | ||
|
|
3f443b2ecc | ||
|
|
eec45ea684 | ||
|
|
f6d77f3cf4 | ||
|
|
d5d7cd1f2d | ||
|
|
a33a187d48 | ||
|
|
70c6ee15c6 | ||
|
|
994d9b1812 | ||
|
|
b19bd28ed3 | ||
|
|
e66bd51845 | ||
|
|
f5614bc2ed | ||
|
|
d9fcf60b7d | ||
|
|
ac6dd1aa5f | ||
|
|
b1495bd6cb | ||
|
|
2d91760295 | ||
|
|
38d4254bc4 | ||
|
|
c0aa734786 | ||
|
|
fa845dbd3d | ||
|
|
fed315ab79 | ||
|
|
726d7316f7 | ||
|
|
ddb387b021 | ||
|
|
d0f19507f5 | ||
|
|
9d997552ad | ||
|
|
9d75c39dcc | ||
|
|
746a84965e | ||
|
|
312d3b58f2 | ||
|
|
b71cd295c2 | ||
|
|
5b3b61a304 | ||
|
|
d661d15723 | ||
|
|
1e38165ad0 | ||
|
|
58a32d7c9d | ||
|
|
6765e883c1 | ||
|
|
18fc608433 | ||
|
|
77f37893b9 | ||
|
|
f1e36e2581 | ||
|
|
772b9153c7 | ||
|
|
4b280a3a7e | ||
|
|
19b6098bf6 | ||
|
|
2aa685320f | ||
|
|
9941be05c2 | ||
|
|
a0a9ab7737 | ||
|
|
a77727a1ce | ||
|
|
47fe032078 | ||
|
|
bdfe279444 | ||
|
|
a86937a54e | ||
|
|
6ef422fbde | ||
|
|
ff0cb6fb88 | ||
|
|
72db90efdf | ||
|
|
5a3fdef3c5 | ||
|
|
ff34b0cae1 | ||
|
|
f064492bb1 | ||
|
|
1427d30541 | ||
|
|
d3730341f0 | ||
|
|
78ac2386f6 | ||
|
|
632ea933b3 | ||
|
|
0f7fa6ebc9 | ||
|
|
6f7f776488 | ||
|
|
f6d7c5e9c5 | ||
|
|
1cc7ecfe8d | ||
|
|
3844e81404 | ||
|
|
fec1f8d32a | ||
|
|
31572e6095 | ||
|
|
4aee38b957 | ||
|
|
232a7705b5 | ||
|
|
a6c2fccd74 | ||
|
|
6a982559cd | ||
|
|
c7904d30de | ||
|
|
ce4386604d | ||
|
|
26b62c520d | ||
|
|
738714bf32 | ||
|
|
41b020bafc | ||
|
|
d0e720272b | ||
|
|
76171da12b | ||
|
|
dcc845d684 | ||
|
|
f1b42c26d5 | ||
|
|
1e94407ae7 | ||
|
|
eb8d9b95fd | ||
|
|
04037a75ed | ||
|
|
2472ceb0a0 | ||
|
|
bfe9bfde2e | ||
|
|
f07e82e361 | ||
|
|
fbbbe5a631 | ||
|
|
5ea603ed78 |
29
.github/workflows/bsd.yml
vendored
Normal file
29
.github/workflows/bsd.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: BSD
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: macos-12
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
- name: Set up
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: stable
|
||||
|
||||
- name: Build
|
||||
run: GOOS=freebsd go test -c ./...
|
||||
|
||||
- name: Test
|
||||
uses: cross-platform-actions/action@v0.21.1
|
||||
with:
|
||||
operating_system: freebsd
|
||||
version: '13.2'
|
||||
sync_files: runner-to-vm
|
||||
run: find . -name '*.test' -maxdepth 1 -exec {} -test.v \;
|
||||
76
.github/workflows/codeql.yml
vendored
76
.github/workflows/codeql.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
#
|
||||
# ******** NOTE ********
|
||||
# We have attempted to detect the languages in your repository. Please check
|
||||
# the `language` matrix defined below to confirm you have the correct set of
|
||||
# supported CodeQL languages.
|
||||
#
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
# The branches below must be a subset of the branches above
|
||||
branches: [ "main" ]
|
||||
schedule:
|
||||
- cron: '15 18 * * 6'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
language: [ 'go' ]
|
||||
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
|
||||
# Use only 'java' to analyze code written in Java, Kotlin or both
|
||||
# Use only 'javascript' to analyze code written in JavaScript, TypeScript or both
|
||||
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v2
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
|
||||
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||
# queries: security-extended,security-and-quality
|
||||
|
||||
|
||||
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
|
||||
|
||||
# If the Autobuild fails above, remove it and uncomment the following three lines.
|
||||
# modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
|
||||
|
||||
# - run: |
|
||||
# echo "Run, Build Application using script"
|
||||
# ./location_of_script_within_repo/buildscript.sh
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
22
.github/workflows/cross.sh
vendored
Executable file
22
.github/workflows/cross.sh
vendored
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
echo android ; GOOS=android GOARCH=amd64 go build .
|
||||
echo darwin ; GOOS=darwin GOARCH=amd64 go build .
|
||||
echo dragonfly ; GOOS=dragonfly GOARCH=amd64 go build .
|
||||
echo freebsd ; GOOS=freebsd GOARCH=amd64 go build .
|
||||
echo illumos ; GOOS=illumos GOARCH=amd64 go build .
|
||||
echo ios ; GOOS=ios GOARCH=amd64 go build .
|
||||
echo linux ; GOOS=linux GOARCH=amd64 go build .
|
||||
echo netbsd ; GOOS=netbsd GOARCH=amd64 go build .
|
||||
echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
|
||||
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
|
||||
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
|
||||
echo windows ; GOOS=windows GOARCH=amd64 go build .
|
||||
# echo aix ; GOOS=aix GOARCH=ppc64 go build .
|
||||
echo js ; GOOS=js GOARCH=wasm go build .
|
||||
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
|
||||
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
|
||||
echo darwin-nosys ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
echo linux-nosys ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
echo windows-nosys ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
21
.github/workflows/cross.yml
vendored
Normal file
21
.github/workflows/cross.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Cross compile
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
- name: Set up
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: stable
|
||||
|
||||
- name: Build
|
||||
run: .github/workflows/cross.sh
|
||||
18
.github/workflows/go.yml
vendored
18
.github/workflows/go.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
@@ -39,7 +39,6 @@ jobs:
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
continue-on-error: true
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
@@ -47,16 +46,19 @@ jobs:
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
|
||||
- name: Test no locks
|
||||
run: go test -v -tags sqlite3_nosys ./tests -run TestDB_nolock
|
||||
|
||||
- name: Test BSD locks
|
||||
run: go test -v -tags sqlite3_bsd ./...
|
||||
run: go test -v -tags sqlite3_flock ./...
|
||||
if: matrix.os == 'macos-latest'
|
||||
|
||||
- name: Coverage report
|
||||
uses: ncruces/go-coverage-report@v0
|
||||
with:
|
||||
chart: 'true'
|
||||
amend: 'true'
|
||||
chart: true
|
||||
amend: true
|
||||
reuse-go: true
|
||||
if: |
|
||||
matrix.os == 'ubuntu-latest' &&
|
||||
github.event_name == 'push'
|
||||
continue-on-error: true
|
||||
github.event_name == 'push' &&
|
||||
matrix.os == 'ubuntu-latest'
|
||||
|
||||
74
README.md
74
README.md
@@ -7,16 +7,28 @@
|
||||
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
|
||||
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
|
||||
|
||||
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
|
||||
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
|
||||
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/blob`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blob)
|
||||
simplifies incremental BLOB I/O.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
|
||||
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
|
||||
registers Unicode aware functions.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
|
||||
implements an in-memory VFS.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
|
||||
implements a VFS for immutable databases.
|
||||
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
|
||||
### Caveats
|
||||
|
||||
@@ -29,57 +41,65 @@ This has benefits, but also comes with some drawbacks.
|
||||
Because WASM does not support shared memory,
|
||||
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
|
||||
|
||||
To work around this limitation, SQLite is compiled with
|
||||
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
|
||||
making `EXCLUSIVE` the default locking mode.
|
||||
For non-WAL databases, `NORMAL` locking mode can be activated with
|
||||
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
|
||||
To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
|
||||
to always use `EXCLUSIVE` locking mode for WAL databases.
|
||||
|
||||
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
|
||||
the `database/sql` driver defaults to `NORMAL` locking mode.
|
||||
To open WAL databases, or use `EXCLUSIVE` locking mode,
|
||||
disable connection pooling by calling
|
||||
to use the [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
with WAL mode databases you should disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### POSIX Advisory Locks
|
||||
#### File Locking
|
||||
|
||||
POSIX advisory locks, which SQLite uses, are
|
||||
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
|
||||
POSIX advisory locks, which SQLite uses on Unix, are
|
||||
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
|
||||
|
||||
On Linux, macOS and illumos, this module uses
|
||||
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
|
||||
to synchronize access to database files.
|
||||
OFD locks are fully compatible with process-associated POSIX advisory locks.
|
||||
OFD locks are fully compatible with POSIX advisory locks.
|
||||
|
||||
On BSD Unixes, this module uses
|
||||
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
|
||||
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
|
||||
On BSD Unixes, BSD locks are fully compatible with POSIX advisory locks.
|
||||
|
||||
On Windows, this module uses `LockFile`, `LockFileEx`, and `UnlockFile`,
|
||||
like SQLite.
|
||||
|
||||
On all other platforms, file locking is not supported, and you must use
|
||||
[`nolock=1`](https://www.sqlite.org/uri.html#urinolock)
|
||||
to open database files.
|
||||
To use the [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
with `nolock=1` you must disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### Testing
|
||||
|
||||
The pure Go VFS is tested by running an unmodified build of SQLite's
|
||||
The pure Go VFS is tested by running SQLite's
|
||||
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
|
||||
on Linux, macOS and Windows.
|
||||
on Linux, macOS, Windows and FreeBSD.
|
||||
Performance is tested by running
|
||||
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
|
||||
|
||||
### Roadmap
|
||||
|
||||
- [ ] advanced SQLite features
|
||||
- [x] custom functions
|
||||
- [x] nested transactions
|
||||
- [x] incremental BLOB I/O
|
||||
- [x] online backup
|
||||
- [x] JSON support
|
||||
- [ ] virtual tables
|
||||
- [ ] session extension
|
||||
- [ ] custom VFSes
|
||||
- [x] custom VFS API
|
||||
- [x] in-memory VFS
|
||||
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
|
||||
- [ ] custom SQL functions
|
||||
|
||||
### Alternatives
|
||||
|
||||
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
|
||||
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
|
||||
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
|
||||
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
|
||||
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
|
||||
|
||||
@@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
|
||||
if r == 0 {
|
||||
defer c.closeDB(other)
|
||||
r = c.call(c.api.errcode, uint64(dst))
|
||||
return nil, c.module.error(r, dst)
|
||||
return nil, c.sqlite.error(r, dst)
|
||||
}
|
||||
|
||||
return &Backup{
|
||||
|
||||
7
blob.go
7
blob.go
@@ -118,8 +118,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
want := int64(1024 * 1024)
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(65536)
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
@@ -175,8 +175,11 @@ func (b *Blob) Write(p []byte) (n int, err error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_write.html
|
||||
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
want := int64(1024 * 1024)
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(65536)
|
||||
if l, ok := r.(*io.LimitedReader); ok && want > l.N {
|
||||
want = l.N
|
||||
}
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
|
||||
114
conn.go
114
conn.go
@@ -2,16 +2,14 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// Conn is a database connection handle.
|
||||
@@ -19,10 +17,9 @@ import (
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/sqlite3.html
|
||||
type Conn struct {
|
||||
*module
|
||||
*sqlite
|
||||
|
||||
interrupt context.Context
|
||||
waiter chan struct{}
|
||||
pending *Stmt
|
||||
arena arena
|
||||
|
||||
@@ -39,7 +36,7 @@ func Open(filename string) (*Conn, error) {
|
||||
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
|
||||
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
|
||||
//
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/open.html
|
||||
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
@@ -49,21 +46,24 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
return newConn(filename, flags)
|
||||
}
|
||||
|
||||
type connKey struct{}
|
||||
|
||||
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
mod, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if conn == nil {
|
||||
mod.close()
|
||||
sqlite.close()
|
||||
} else {
|
||||
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
|
||||
}
|
||||
}()
|
||||
|
||||
c := &Conn{module: mod}
|
||||
c := &Conn{sqlite: sqlite}
|
||||
c.arena = c.newArena(1024)
|
||||
c.ctx = context.WithValue(c.ctx, connKey{}, c)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
|
||||
handle := util.ReadUint32(c.mod, connPtr)
|
||||
if err := c.module.error(r, handle); err != nil {
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
c.arena.reset()
|
||||
pragmaPtr := c.arena.string(pragmas.String())
|
||||
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
|
||||
if err := c.module.error(r, handle, pragmas.String()); err != nil {
|
||||
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
|
||||
if errors.Is(err, ERROR) {
|
||||
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
|
||||
func (c *Conn) closeDB(handle uint32) {
|
||||
r := c.call(c.api.closeZombie, uint64(handle))
|
||||
if err := c.module.error(r, handle); err != nil {
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -132,7 +132,6 @@ func (c *Conn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.SetInterrupt(context.Background())
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
|
||||
@@ -143,7 +142,7 @@ func (c *Conn) Close() error {
|
||||
|
||||
c.handle = 0
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return c.module.close()
|
||||
return c.close()
|
||||
}
|
||||
|
||||
// Exec is a convenience function that allows an application to run
|
||||
@@ -240,65 +239,45 @@ func (c *Conn) Changes() int64 {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
<-c.waiter // Wait for it to finish.
|
||||
c.waiter = nil
|
||||
// Is it the same context?
|
||||
if ctx == c.interrupt {
|
||||
return ctx
|
||||
}
|
||||
// Reset the pending statement.
|
||||
if c.pending != nil {
|
||||
|
||||
// An uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
} else {
|
||||
c.pending.Reset()
|
||||
}
|
||||
|
||||
old = c.interrupt
|
||||
c.interrupt = ctx
|
||||
// Remove the handler if the context can't be canceled.
|
||||
if ctx == nil || ctx.Done() == nil {
|
||||
c.call(c.api.progressHandler, uint64(c.handle), 0)
|
||||
return old
|
||||
}
|
||||
|
||||
// Creating an uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
}
|
||||
c.pending.Step()
|
||||
|
||||
// Don't create the goroutine if we're already interrupted.
|
||||
// This happens frequently while restoring to a previously interrupted state.
|
||||
if c.checkInterrupt() {
|
||||
return old
|
||||
}
|
||||
|
||||
waiter := make(chan struct{})
|
||||
c.waiter = waiter
|
||||
go func() {
|
||||
select {
|
||||
case <-waiter: // Waiter was cancelled.
|
||||
break
|
||||
|
||||
case <-ctx.Done(): // Done was closed.
|
||||
const isInterruptedOffset = 280
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
}
|
||||
|
||||
// Signal that the waiter has finished.
|
||||
waiter <- struct{}{}
|
||||
}()
|
||||
c.call(c.api.progressHandler, uint64(c.handle), 100)
|
||||
return old
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() bool {
|
||||
if c.interrupt == nil || c.interrupt.Err() == nil {
|
||||
return false
|
||||
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
|
||||
if c.interrupt != nil && c.interrupt.Err() != nil {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() {
|
||||
if c.interrupt != nil && c.interrupt.Err() != nil {
|
||||
c.call(c.api.interrupt, uint64(c.handle))
|
||||
}
|
||||
const isInterruptedOffset = 280
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pragma executes a PRAGMA statement and returns any results.
|
||||
@@ -319,27 +298,14 @@ func (c *Conn) Pragma(str string) ([]string, error) {
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
return c.module.error(rc, c.handle, sql...)
|
||||
return c.sqlite.error(rc, c.handle, sql...)
|
||||
}
|
||||
|
||||
// DriverConn is implemented by the SQLite [database/sql] driver connection.
|
||||
//
|
||||
// It can be used to access advanced SQLite features like
|
||||
// [savepoints], [online backup] and [incremental BLOB I/O].
|
||||
// It can be used to access SQLite features like [online backup].
|
||||
//
|
||||
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
|
||||
// [online backup]: https://www.sqlite.org/backup.html
|
||||
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
|
||||
type DriverConn interface {
|
||||
driver.Conn
|
||||
driver.ConnBeginTx
|
||||
driver.ExecerContext
|
||||
driver.ConnPrepareContext
|
||||
|
||||
SetInterrupt(ctx context.Context) (old context.Context)
|
||||
|
||||
Savepoint() Savepoint
|
||||
Backup(srcDB, dstURI string) error
|
||||
Restore(dstDB, srcURI string) error
|
||||
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
|
||||
Raw() *Conn
|
||||
}
|
||||
|
||||
21
const.go
21
const.go
@@ -97,6 +97,7 @@ const (
|
||||
IOERR_ROLLBACK_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (31 << 8)
|
||||
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
|
||||
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
|
||||
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
|
||||
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
|
||||
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
|
||||
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)
|
||||
@@ -167,6 +168,18 @@ const (
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
)
|
||||
|
||||
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_deterministic.html
|
||||
type FunctionFlag uint32
|
||||
|
||||
const (
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
SUBTYPE FunctionFlag = 0x000100000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
)
|
||||
|
||||
// Datatype is a fundamental datatype of SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_blob.html
|
||||
@@ -182,18 +195,18 @@ const (
|
||||
|
||||
// String implements the [fmt.Stringer] interface.
|
||||
func (t Datatype) String() string {
|
||||
const name = "INTEGERFLOATTEXTBLOBNULL"
|
||||
const name = "INTEGERFLOATEXTBLOBNULL"
|
||||
switch t {
|
||||
case INTEGER:
|
||||
return name[0:7]
|
||||
case FLOAT:
|
||||
return name[7:12]
|
||||
case TEXT:
|
||||
return name[12:16]
|
||||
return name[11:15]
|
||||
case BLOB:
|
||||
return name[16:20]
|
||||
return name[15:19]
|
||||
case NULL:
|
||||
return name[20:24]
|
||||
return name[19:23]
|
||||
}
|
||||
return strconv.FormatUint(uint64(t), 10)
|
||||
}
|
||||
|
||||
219
context.go
Normal file
219
context.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Context is the context in which an SQL function executes.
|
||||
// An SQLite [Context] is in no way related to a Go [context.Context].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
c *Conn
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Conn returns the database connection of the
|
||||
// [Conn.CreateFunction] or [Conn.CreateWindowFunction]
|
||||
// routines that originally registered the application defined function.
|
||||
//
|
||||
// https://sqlite.org/c3ref/context_db_handle.html
|
||||
func (ctx Context) Conn() *Conn {
|
||||
return ctx.c
|
||||
}
|
||||
|
||||
// SetAuxData saves metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (ctx Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(ctx.c.ctx, data)
|
||||
ctx.c.call(ctx.c.api.setAuxData, uint64(ctx.handle), uint64(n), uint64(ptr))
|
||||
}
|
||||
|
||||
// GetAuxData returns metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (ctx Context) GetAuxData(n int) any {
|
||||
ptr := uint32(ctx.c.call(ctx.c.api.getAuxData, uint64(ctx.handle), uint64(n)))
|
||||
return util.GetHandle(ctx.c.ctx, ptr)
|
||||
}
|
||||
|
||||
// ResultBool sets the result of the function to a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultBool(value bool) {
|
||||
var i int64
|
||||
if value {
|
||||
i = 1
|
||||
}
|
||||
ctx.ResultInt64(i)
|
||||
}
|
||||
|
||||
// ResultInt sets the result of the function to an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultInt(value int) {
|
||||
ctx.ResultInt64(int64(value))
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of the function to an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultInt64(value int64) {
|
||||
ctx.c.call(ctx.c.api.resultInteger,
|
||||
uint64(ctx.handle), uint64(value))
|
||||
}
|
||||
|
||||
// ResultFloat sets the result of the function to a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultFloat(value float64) {
|
||||
ctx.c.call(ctx.c.api.resultFloat,
|
||||
uint64(ctx.handle), math.Float64bits(value))
|
||||
}
|
||||
|
||||
// ResultText sets the result of the function to a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultText(value string) {
|
||||
ptr := ctx.c.newString(value)
|
||||
ctx.c.call(ctx.c.api.resultText,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(ctx.c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultBlob(value []byte) {
|
||||
ptr := ctx.c.newBytes(value)
|
||||
ctx.c.call(ctx.c.api.resultBlob,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(ctx.c.api.destructor))
|
||||
}
|
||||
|
||||
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultZeroBlob(n int64) {
|
||||
ctx.c.call(ctx.c.api.resultZeroBlob,
|
||||
uint64(ctx.handle), uint64(n))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of the function to NULL.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultNull() {
|
||||
ctx.c.call(ctx.c.api.resultNull,
|
||||
uint64(ctx.handle))
|
||||
}
|
||||
|
||||
// ResultTime sets the result of the function to a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
if format == TimeFormatDefault {
|
||||
ctx.resultRFC3339Nano(value)
|
||||
return
|
||||
}
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
ctx.ResultText(v)
|
||||
case int64:
|
||||
ctx.ResultInt64(v)
|
||||
case float64:
|
||||
ctx.ResultFloat(v)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano)) + 5
|
||||
|
||||
ptr := ctx.c.new(maxlen)
|
||||
buf := util.View(ctx.c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
ctx.c.call(ctx.c.api.resultText,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(ctx.c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
|
||||
// except that it also associates ptr with that NULL value such that it can be retrieved
|
||||
// within an application-defined SQL function using [Value.Pointer].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultPointer(ptr any) {
|
||||
valPtr := util.AddHandle(ctx.c.ctx, ptr)
|
||||
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
|
||||
}
|
||||
|
||||
// ResultJSON sets the result of the function to the JSON encoding of value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultJSON(value any) {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
}
|
||||
ptr := ctx.c.newBytes(data)
|
||||
ctx.c.call(ctx.c.api.resultText,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(data)),
|
||||
uint64(ctx.c.api.destructor))
|
||||
}
|
||||
|
||||
// ResultValue sets the result of the function a copy of [Value].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultValue(value Value) {
|
||||
if value.sqlite != ctx.c.sqlite {
|
||||
ctx.ResultError(MISUSE)
|
||||
}
|
||||
ctx.c.call(ctx.c.api.resultValue,
|
||||
uint64(ctx.handle), uint64(value.handle))
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
ctx.c.call(ctx.c.api.resultErrorMem, uint64(ctx.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
ctx.c.call(ctx.c.api.resultErrorBig, uint64(ctx.handle))
|
||||
return
|
||||
}
|
||||
|
||||
str := err.Error()
|
||||
ptr := ctx.c.newString(str)
|
||||
ctx.c.call(ctx.c.api.resultError,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(str)))
|
||||
ctx.c.free(ptr)
|
||||
|
||||
var code uint64
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
switch {
|
||||
case errors.As(err, &xcode):
|
||||
code = uint64(xcode)
|
||||
case errors.As(err, &ecode):
|
||||
code = uint64(ecode)
|
||||
}
|
||||
if code != 0 {
|
||||
ctx.c.call(ctx.c.api.resultErrorCode,
|
||||
uint64(ctx.handle), code)
|
||||
}
|
||||
}
|
||||
203
driver/driver.go
203
driver/driver.go
@@ -14,10 +14,9 @@
|
||||
//
|
||||
// [PRAGMA] statements can be specified using "_pragma":
|
||||
//
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// If no PRAGMAs are specified, a busy timeout of 1 minute
|
||||
// and normal locking mode are used.
|
||||
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
|
||||
//
|
||||
// Order matters:
|
||||
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
|
||||
@@ -31,6 +30,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
@@ -41,64 +41,117 @@ import (
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// This variable can be replaced with -ldflags:
|
||||
//
|
||||
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
|
||||
var driverName = "sqlite3"
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", sqlite{})
|
||||
if driverName != "" {
|
||||
sql.Register(driverName, sqlite{})
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
|
||||
//
|
||||
// The init function is called by the driver on new connections.
|
||||
// The conn can be used to execute queries, register functions, etc.
|
||||
// Any error return closes the conn and passes the error to [database/sql].
|
||||
func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
|
||||
c, err := newConnector(dataSourceName, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.OpenDB(c), nil
|
||||
}
|
||||
|
||||
type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.Conn, err = sqlite3.Open(name)
|
||||
func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
c, err := newConnector(name, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Connect(context.Background())
|
||||
}
|
||||
|
||||
var pragmas []string
|
||||
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
|
||||
return newConnector(name, nil)
|
||||
}
|
||||
|
||||
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
|
||||
c := connector{name: name, init: init}
|
||||
if strings.HasPrefix(name, "file:") {
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, _ := url.ParseQuery(after)
|
||||
|
||||
switch s := query.Get("_txlock"); s {
|
||||
case "":
|
||||
c.txBegin = "BEGIN"
|
||||
case "deferred", "immediate", "exclusive":
|
||||
c.txBegin = "BEGIN " + s
|
||||
default:
|
||||
c.Close()
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
|
||||
query, err := url.ParseQuery(after)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pragmas = query["_pragma"]
|
||||
c.txlock = query.Get("_txlock")
|
||||
c.pragmas = len(query["_pragma"]) > 0
|
||||
}
|
||||
}
|
||||
if len(pragmas) == 0 {
|
||||
err := c.Conn.Exec(`
|
||||
PRAGMA busy_timeout=60000;
|
||||
PRAGMA locking_mode=normal;
|
||||
`)
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type connector struct {
|
||||
init func(*sqlite3.Conn) error
|
||||
name string
|
||||
txlock string
|
||||
pragmas bool
|
||||
}
|
||||
|
||||
func (n *connector) Driver() driver.Driver {
|
||||
return sqlite{}
|
||||
}
|
||||
|
||||
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.Conn, err = sqlite3.Open(n.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
c.reusable = true
|
||||
} else {
|
||||
s, _, err := c.Conn.Prepare(`
|
||||
SELECT * FROM
|
||||
PRAGMA_locking_mode,
|
||||
PRAGMA_query_only;
|
||||
`)
|
||||
}()
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
switch n.txlock {
|
||||
case "":
|
||||
c.txBegin = "BEGIN"
|
||||
case "deferred", "immediate", "exclusive":
|
||||
c.txBegin = "BEGIN " + n.txlock
|
||||
default:
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
|
||||
}
|
||||
if !n.pragmas {
|
||||
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
if s.Step() {
|
||||
c.reusable = s.ColumnText(0) == "normal"
|
||||
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
|
||||
}
|
||||
if n.init != nil {
|
||||
err = n.init(c.Conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if n.pragmas || n.init != nil {
|
||||
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.Step() && s.ColumnBool(0) {
|
||||
c.readOnly = '1'
|
||||
} else {
|
||||
c.readOnly = '0'
|
||||
}
|
||||
err = s.Close()
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -110,20 +163,19 @@ type conn struct {
|
||||
txBegin string
|
||||
txCommit string
|
||||
txRollback string
|
||||
reusable bool
|
||||
readOnly byte
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
_ driver.ConnPrepareContext = &conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
func (c *conn) Raw() *sqlite3.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
@@ -139,10 +191,10 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
txBegin = `
|
||||
BEGIN deferred;
|
||||
PRAGMA query_only=on`
|
||||
c.txCommit = `
|
||||
c.txRollback = `
|
||||
ROLLBACK;
|
||||
PRAGMA query_only=` + string(c.readOnly)
|
||||
c.txRollback = c.txCommit
|
||||
c.txCommit = c.txRollback
|
||||
}
|
||||
|
||||
switch opts.Isolation {
|
||||
@@ -166,14 +218,20 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
|
||||
func (c *conn) Commit() error {
|
||||
err := c.Conn.Exec(c.txCommit)
|
||||
if err != nil && !c.GetAutocommit() {
|
||||
if err != nil && !c.Conn.GetAutocommit() {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Rollback() error {
|
||||
return c.Conn.Exec(c.txRollback)
|
||||
err := c.Conn.Exec(c.txRollback)
|
||||
if errors.Is(err, sqlite3.INTERRUPT) {
|
||||
old := c.Conn.SetInterrupt(context.Background())
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
err = c.Conn.Exec(c.txRollback)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
@@ -210,6 +268,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
if savept, ok := ctx.(*saveptCtx); ok {
|
||||
// Called from driver.Savepoint.
|
||||
savept.Savepoint = c.Savepoint()
|
||||
return resultRowsAffected(0), nil
|
||||
}
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
@@ -221,6 +285,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
|
||||
return newResult(c.Conn), nil
|
||||
}
|
||||
|
||||
func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
@@ -258,13 +326,14 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
}
|
||||
|
||||
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
// Use QueryContext to setup bindings.
|
||||
// No need to close rows: that simply resets the statement, exec does the same.
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
err := s.setupBindings(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
old := s.Conn.SetInterrupt(ctx)
|
||||
defer s.Conn.SetInterrupt(old)
|
||||
|
||||
err = s.Stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -274,10 +343,18 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
}
|
||||
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.Stmt.ClearBindings()
|
||||
err := s.setupBindings(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
}
|
||||
|
||||
func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
err := s.Stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ids [3]int
|
||||
for _, arg := range args {
|
||||
@@ -310,6 +387,10 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
|
||||
err = s.Stmt.BindZeroBlob(id, int64(a))
|
||||
case time.Time:
|
||||
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
case interface{ Pointer() any }:
|
||||
err = s.Stmt.BindPointer(id, a.Pointer())
|
||||
case interface{ JSON() any }:
|
||||
err = s.Stmt.BindJSON(id, a.JSON())
|
||||
case nil:
|
||||
err = s.Stmt.BindNull(id)
|
||||
default:
|
||||
@@ -317,17 +398,19 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, nil:
|
||||
sqlite3.ZeroBlob, time.Time,
|
||||
interface{ Pointer() any },
|
||||
interface{ JSON() any },
|
||||
nil:
|
||||
return nil
|
||||
default:
|
||||
return driver.ErrSkip
|
||||
@@ -406,11 +489,7 @@ func (r *rows) Next(dest []driver.Value) error {
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
} else {
|
||||
dest[i] = nil
|
||||
}
|
||||
dest[i] = nil
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
65
driver/json_test.go
Normal file
65
driver/json_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package driver_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func Example_json() {
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE orders (
|
||||
cart_id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
cart TEXT
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
type CartItem struct {
|
||||
ItemID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Quantity int `json:"quantity,omitempty"`
|
||||
Price int `json:"price,omitempty"`
|
||||
}
|
||||
|
||||
type Cart struct {
|
||||
Items []CartItem `json:"items"`
|
||||
}
|
||||
|
||||
_, err = db.Exec(`INSERT INTO orders (user_id, cart) VALUES (?, ?)`, 123, sqlite3.JSON(Cart{
|
||||
[]CartItem{
|
||||
{ItemID: "111", Name: "T-shirt", Quantity: 1, Price: 250},
|
||||
{ItemID: "222", Name: "Trousers", Quantity: 1, Price: 600},
|
||||
},
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var total string
|
||||
err = db.QueryRow(`
|
||||
SELECT total(json_each.value -> 'price')
|
||||
FROM orders, json_each(cart -> 'items')
|
||||
WHERE cart_id = last_insert_rowid()
|
||||
`).Scan(&total)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("total:", total)
|
||||
// Output:
|
||||
// total: 850
|
||||
}
|
||||
27
driver/savepoint.go
Normal file
27
driver/savepoint.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Savepoint establishes a new transaction savepoint.
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func Savepoint(tx *sql.Tx) sqlite3.Savepoint {
|
||||
var ctx saveptCtx
|
||||
tx.ExecContext(&ctx, "")
|
||||
return ctx.Savepoint
|
||||
}
|
||||
|
||||
type saveptCtx struct{ sqlite3.Savepoint }
|
||||
|
||||
func (*saveptCtx) Deadline() (deadline time.Time, ok bool) { return }
|
||||
|
||||
func (*saveptCtx) Done() <-chan struct{} { return nil }
|
||||
|
||||
func (*saveptCtx) Err() error { return nil }
|
||||
|
||||
func (*saveptCtx) Value(key any) any { return nil }
|
||||
87
driver/savepoint_test.go
Normal file
87
driver/savepoint_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package driver_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func ExampleSavepoint() {
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = func() error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.Prepare(`INSERT INTO users (id, name) VALUES (?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(0, "go")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(1, "zig")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
savept := driver.Savepoint(tx)
|
||||
|
||||
_, err = stmt.Exec(2, "whatever")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = savept.Rollback()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(3, "rust")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s %s\n", id, name)
|
||||
}
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 3 rust
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
func ExampleDriverConn() {
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", "demo.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove("demo.db")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(sqlite3.DriverConn)
|
||||
savept := conn.Savepoint()
|
||||
defer savept.Release(&err)
|
||||
|
||||
blob, err := conn.OpenBlob("main", "test", "col", id, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
_, err = fmt.Fprint(blob, "Hello BLOB!")
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var msg string
|
||||
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(msg)
|
||||
// Output:
|
||||
// Hello BLOB!
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# Embeddable WASM build of SQLite
|
||||
|
||||
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
|
||||
This folder includes an embeddable WASM build of SQLite 3.44.0 for use with
|
||||
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
|
||||
|
||||
The following optional features are compiled in:
|
||||
@@ -9,6 +9,7 @@ The following optional features are compiled in:
|
||||
- [JSON](https://www.sqlite.org/json1.html)
|
||||
- [R*Tree](https://www.sqlite.org/rtree.html)
|
||||
- [GeoPoly](https://www.sqlite.org/geopoly.html)
|
||||
- [soundex](https://www.sqlite.org/lang_corefunc.html#soundex)
|
||||
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
|
||||
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
|
||||
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
|
||||
|
||||
@@ -4,24 +4,27 @@ set -euo pipefail
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
ROOT=../
|
||||
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
|
||||
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
|
||||
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-mexec-model=reactor \
|
||||
-mmutable-globals \
|
||||
-msimd128 -mmutable-globals \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-Wl,--initial-memory=327680 \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
||||
|
||||
trap 'rm -f sqlite3.tmp' EXIT
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-multivalue --enable-mutable-globals \
|
||||
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
|
||||
sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
@@ -13,6 +13,8 @@ sqlite3_finalize
|
||||
sqlite3_reset
|
||||
sqlite3_step
|
||||
sqlite3_exec
|
||||
sqlite3_interrupt
|
||||
sqlite3_progress_handler_go
|
||||
sqlite3_clear_bindings
|
||||
sqlite3_bind_parameter_count
|
||||
sqlite3_bind_parameter_index
|
||||
@@ -23,6 +25,7 @@ sqlite3_bind_double
|
||||
sqlite3_bind_text64
|
||||
sqlite3_bind_blob64
|
||||
sqlite3_bind_zeroblob64
|
||||
sqlite3_bind_pointer_go
|
||||
sqlite3_column_count
|
||||
sqlite3_column_name
|
||||
sqlite3_column_type
|
||||
@@ -33,10 +36,10 @@ sqlite3_column_blob
|
||||
sqlite3_column_bytes
|
||||
sqlite3_blob_open
|
||||
sqlite3_blob_close
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_blob_bytes
|
||||
sqlite3_blob_read
|
||||
sqlite3_blob_write
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_backup_init
|
||||
sqlite3_backup_step
|
||||
sqlite3_backup_finish
|
||||
@@ -46,4 +49,32 @@ sqlite3_uri_parameter
|
||||
sqlite3_uri_key
|
||||
sqlite3_changes64
|
||||
sqlite3_last_insert_rowid
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_anycollseq_init
|
||||
sqlite3_create_collation_go
|
||||
sqlite3_create_function_go
|
||||
sqlite3_create_aggregate_function_go
|
||||
sqlite3_create_window_function_go
|
||||
sqlite3_aggregate_context
|
||||
sqlite3_user_data
|
||||
sqlite3_set_auxdata_go
|
||||
sqlite3_get_auxdata
|
||||
sqlite3_value_type
|
||||
sqlite3_value_int64
|
||||
sqlite3_value_double
|
||||
sqlite3_value_text
|
||||
sqlite3_value_blob
|
||||
sqlite3_value_bytes
|
||||
sqlite3_value_pointer_go
|
||||
sqlite3_result_null
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_double
|
||||
sqlite3_result_text64
|
||||
sqlite3_result_blob64
|
||||
sqlite3_result_zeroblob64
|
||||
sqlite3_result_pointer_go
|
||||
sqlite3_result_value
|
||||
sqlite3_result_error
|
||||
sqlite3_result_error_code
|
||||
sqlite3_result_error_nomem
|
||||
sqlite3_result_error_toobig
|
||||
Binary file not shown.
22
error.go
22
error.go
@@ -68,6 +68,19 @@ func (e *Error) Is(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
|
||||
func (e *Error) As(err any) bool {
|
||||
switch c := err.(type) {
|
||||
case *ErrorCode:
|
||||
*c = e.Code()
|
||||
return true
|
||||
case *ExtendedErrorCode:
|
||||
*c = e.ExtendedCode()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.Code() == BUSY
|
||||
@@ -104,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool {
|
||||
return ok && c == ErrorCode(e)
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode].
|
||||
func (e ExtendedErrorCode) As(err any) bool {
|
||||
c, ok := err.(*ErrorCode)
|
||||
if ok {
|
||||
*c = ErrorCode(e)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ExtendedErrorCode) Temporary() bool {
|
||||
return ErrorCode(e) == BUSY
|
||||
|
||||
@@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) {
|
||||
func TestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := Error{code: 0x8080}
|
||||
if rc := err.Code(); rc != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", rc)
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
err := &Error{code: 0x8080}
|
||||
if !errors.As(err, &err) {
|
||||
t.Fatal("want true")
|
||||
}
|
||||
if !errors.Is(&err, ErrorCode(0x80)) {
|
||||
if ecode := err.Code(); ecode != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err, ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if rc := err.ExtendedCode(); rc != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", rc)
|
||||
if xcode := err.ExtendedCode(); xcode != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
|
||||
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(err, xErrorCode(0x8080)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if s := err.Error(); s != "sqlite3: 32896" {
|
||||
t.Errorf("got %q", s)
|
||||
}
|
||||
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
|
||||
59
ext/blob/blob.go
Normal file
59
ext/blob/blob.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package blob provides an alternative interface to incremental BLOB I/O.
|
||||
package blob
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register registers the blob_open SQL function.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
db.CreateFunction("blob_open", -1,
|
||||
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
|
||||
}
|
||||
|
||||
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) < 6 {
|
||||
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
|
||||
return
|
||||
}
|
||||
|
||||
row := arg[3].Int64()
|
||||
|
||||
var err error
|
||||
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
|
||||
if ok {
|
||||
err = blob.Reopen(row)
|
||||
if errors.Is(err, sqlite3.MISUSE) {
|
||||
// Blob was closed (db, table or column changed).
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
db := arg[0].Text()
|
||||
table := arg[1].Text()
|
||||
column := arg[2].Text()
|
||||
write := arg[4].Bool()
|
||||
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
|
||||
}
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
|
||||
fn := arg[5].Pointer().(OpenCallback)
|
||||
err = fn(blob, arg[6:]...)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// This ensures the blob is closed if db, table or column change.
|
||||
ctx.SetAuxData(0, blob)
|
||||
ctx.SetAuxData(1, blob)
|
||||
ctx.SetAuxData(2, blob)
|
||||
}
|
||||
|
||||
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
|
||||
61
ext/blob/blob_test.go
Normal file
61
ext/blob/blob_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package blob_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/blob"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
// Open the database, registering the extension.
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
|
||||
blob.Register(conn)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
const message = "Hello BLOB!"
|
||||
|
||||
// Create the BLOB.
|
||||
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Write the BLOB.
|
||||
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
|
||||
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
|
||||
_, err = io.WriteString(blob, message)
|
||||
return err
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the BLOB.
|
||||
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
|
||||
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
|
||||
_, err = io.Copy(os.Stdout, blob)
|
||||
return err
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// Hello BLOB!
|
||||
}
|
||||
109
ext/stats/stats.go
Normal file
109
ext/stats/stats.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Package stats provides aggregate functions for statistics.
|
||||
//
|
||||
// Functions:
|
||||
// - stddev_pop: population standard deviation
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - var_pop: population variance
|
||||
// - var_samp: sample variance
|
||||
// - covar_pop: population covariance
|
||||
// - covar_samp: sample covariance
|
||||
// - corr: correlation coefficient
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
//
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
|
||||
}
|
||||
|
||||
const (
|
||||
var_pop = iota
|
||||
var_samp
|
||||
stddev_pop
|
||||
stddev_samp
|
||||
corr
|
||||
)
|
||||
|
||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
||||
}
|
||||
|
||||
type variance struct {
|
||||
kind int
|
||||
welford
|
||||
}
|
||||
|
||||
func (fn *variance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.var_pop()
|
||||
case var_samp:
|
||||
r = fn.var_samp()
|
||||
case stddev_pop:
|
||||
r = fn.stddev_pop()
|
||||
case stddev_samp:
|
||||
r = fn.stddev_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func newCovariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
||||
}
|
||||
|
||||
type covariance struct {
|
||||
kind int
|
||||
welford2
|
||||
}
|
||||
|
||||
func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.covar_pop()
|
||||
case var_samp:
|
||||
r = fn.covar_samp()
|
||||
case corr:
|
||||
r = fn.correlation()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
140
ext/stats/stats_test.go
Normal file
140
ext/stats/stats_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister_variance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT
|
||||
sum(x), avg(x),
|
||||
var_samp(x), var_pop(x),
|
||||
stddev_samp(x), stddev_pop(x)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 40 {
|
||||
t.Errorf("got %v, want 40", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(3); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 4.5, 18, 0, 0}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_covariance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 10, 30, 75, 22.5}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
109
ext/stats/welford.go
Normal file
109
ext/stats/welford.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package stats
|
||||
|
||||
import "math"
|
||||
|
||||
// Welford's algorithm with Kahan summation:
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
type welford struct {
|
||||
m1, m2 kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford) average() float64 {
|
||||
return w.m1.hi
|
||||
}
|
||||
|
||||
func (w welford) var_pop() float64 {
|
||||
return w.m2.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford) var_samp() float64 {
|
||||
return w.m2.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford) stddev_pop() float64 {
|
||||
return math.Sqrt(w.var_pop())
|
||||
}
|
||||
|
||||
func (w welford) stddev_samp() float64 {
|
||||
return math.Sqrt(w.var_samp())
|
||||
}
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
|
||||
type welford2 struct {
|
||||
m1x, m2x kahan
|
||||
m1y, m2y kahan
|
||||
cov kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford2) covar_pop() float64 {
|
||||
return w.cov.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford2) covar_samp() float64 {
|
||||
return w.cov.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford2) correlation() float64 {
|
||||
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(x, y float64) {
|
||||
w.n++
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.add(d1x / float64(w.n))
|
||||
w.m1y.add(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.add(d1x * d2x)
|
||||
w.m2y.add(d1y * d2y)
|
||||
w.cov.add(d1x * d2y)
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(x, y float64) {
|
||||
w.n--
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.sub(d1x / float64(w.n))
|
||||
w.m1y.sub(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.sub(d1x * d2x)
|
||||
w.m2y.sub(d1y * d2y)
|
||||
w.cov.sub(d1x * d2y)
|
||||
}
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
y := k.lo + x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
func (k *kahan) sub(x float64) {
|
||||
y := k.lo - x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
75
ext/stats/welford_test.go
Normal file
75
ext/stats/welford_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_welford(t *testing.T) {
|
||||
var s1, s2 welford
|
||||
|
||||
s1.enqueue(4)
|
||||
s1.enqueue(7)
|
||||
s1.enqueue(13)
|
||||
s1.enqueue(16)
|
||||
if got := s1.average(); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := s1.var_samp(); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := s1.var_pop(); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := s1.stddev_samp(); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
|
||||
s1.dequeue(4)
|
||||
s2.enqueue(7)
|
||||
s2.enqueue(13)
|
||||
s2.enqueue(16)
|
||||
if s1.var_pop() != s2.var_pop() {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_covar(t *testing.T) {
|
||||
var c1, c2 welford2
|
||||
|
||||
c1.enqueue(3, 70)
|
||||
c1.enqueue(5, 80)
|
||||
c1.enqueue(2, 60)
|
||||
c1.enqueue(7, 90)
|
||||
c1.enqueue(4, 75)
|
||||
|
||||
if got := c1.covar_samp(); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := c1.covar_pop(); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
|
||||
c1.dequeue(3, 70)
|
||||
c2.enqueue(5, 80)
|
||||
c2.enqueue(2, 60)
|
||||
c2.enqueue(7, 90)
|
||||
c2.enqueue(4, 75)
|
||||
if c1.covar_pop() != c2.covar_pop() {
|
||||
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_correlation(t *testing.T) {
|
||||
var c welford2
|
||||
c.enqueue(1, 3)
|
||||
c.enqueue(2, 2)
|
||||
c.enqueue(3, 1)
|
||||
|
||||
if got := c.correlation(); got != -1 {
|
||||
t.Errorf("got %v, want -1", got)
|
||||
}
|
||||
}
|
||||
181
ext/unicode/unicode.go
Normal file
181
ext/unicode/unicode.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Package unicode provides an alternative to the SQLite ICU extension.
|
||||
//
|
||||
// Like the [ICU extension], it provides Unicode aware:
|
||||
// - upper() and lower() functions,
|
||||
// - LIKE and REGEXP operators,
|
||||
// - collation sequences.
|
||||
//
|
||||
// The implementation is not 100% compatible with the [ICU extension]:
|
||||
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
|
||||
// - the LIKE operator follows [strings.EqualFold] rules;
|
||||
// - the REGEXP operator uses Go [regex/syntax];
|
||||
// - collation sequences use [collate].
|
||||
//
|
||||
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
|
||||
//
|
||||
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
|
||||
db.CreateFunction("like", 2, flags, like)
|
||||
db.CreateFunction("like", 3, flags, like)
|
||||
db.CreateFunction("upper", 1, flags, upper)
|
||||
db.CreateFunction("upper", 2, flags, upper)
|
||||
db.CreateFunction("lower", 1, flags, lower)
|
||||
db.CreateFunction("lower", 2, flags, lower)
|
||||
db.CreateFunction("regexp", 2, flags, regex)
|
||||
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
name := arg[1].Text()
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := RegisterCollation(db, arg[0].Text(), name)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterCollation registers a Unicode collation sequence for a database connection.
|
||||
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
|
||||
tag, err := language.Parse(locale)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.CreateCollation(name, collate.New(tag).Compare)
|
||||
}
|
||||
|
||||
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) == 1 {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
return
|
||||
}
|
||||
cs, ok := ctx.GetAuxData(1).(cases.Caser)
|
||||
if !ok {
|
||||
t, err := language.Parse(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
c := cases.Upper(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
}
|
||||
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
|
||||
}
|
||||
|
||||
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) == 1 {
|
||||
ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob()))
|
||||
return
|
||||
}
|
||||
cs, ok := ctx.GetAuxData(1).(cases.Caser)
|
||||
if !ok {
|
||||
t, err := language.Parse(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
c := cases.Lower(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
}
|
||||
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
|
||||
}
|
||||
|
||||
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
}
|
||||
|
||||
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
escape := rune(-1)
|
||||
if len(arg) == 3 {
|
||||
var size int
|
||||
b := arg[2].RawBlob()
|
||||
escape, size = utf8.DecodeRune(b)
|
||||
if size != len(b) {
|
||||
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type likeData struct {
|
||||
*regexp.Regexp
|
||||
escape rune
|
||||
}
|
||||
|
||||
re, ok := ctx.GetAuxData(0).(likeData)
|
||||
if !ok || re.escape != escape {
|
||||
re = likeData{
|
||||
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
|
||||
escape,
|
||||
}
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
}
|
||||
|
||||
func like2regex(pattern string, escape rune) string {
|
||||
var re strings.Builder
|
||||
start := 0
|
||||
literal := false
|
||||
re.Grow(len(pattern) + 10)
|
||||
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
|
||||
for i, r := range pattern {
|
||||
if start < 0 {
|
||||
start = i
|
||||
}
|
||||
if literal {
|
||||
literal = false
|
||||
continue
|
||||
}
|
||||
var symbol string
|
||||
switch r {
|
||||
case '_':
|
||||
symbol = `.`
|
||||
case '%':
|
||||
symbol = `.*`
|
||||
case escape:
|
||||
literal = true
|
||||
default:
|
||||
continue
|
||||
}
|
||||
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
|
||||
re.WriteString(symbol)
|
||||
start = -1
|
||||
}
|
||||
if start >= 0 {
|
||||
re.WriteString(regexp.QuoteMeta(pattern[start:]))
|
||||
}
|
||||
re.WriteString(`\z`)
|
||||
return re.String()
|
||||
}
|
||||
215
ext/unicode/unicode_test.go
Normal file
215
ext/unicode/unicode_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := func(fn string) string {
|
||||
stmt, _, err := db.Prepare(`SELECT ` + fn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0)
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return ""
|
||||
}
|
||||
|
||||
Register(db)
|
||||
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{`upper('hello')`, "HELLO"},
|
||||
{`lower('HELLO')`, "hello"},
|
||||
{`upper('привет')`, "ПРИВЕТ"},
|
||||
{`lower('ПРИВЕТ')`, "привет"},
|
||||
{`upper('istanbul')`, "ISTANBUL"},
|
||||
{`upper('istanbul', 'tr-TR')`, "İSTANBUL"},
|
||||
{`lower('Dünyanın İlk Borsası', 'tr-TR')`, "dünyanın ilk borsası"},
|
||||
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
|
||||
{`'Hello' REGEXP 'ell'`, "1"},
|
||||
{`'Hello' REGEXP 'el.'`, "1"},
|
||||
{`'Hello' LIKE 'hel_'`, "0"},
|
||||
{`'Hello' LIKE 'hel%'`, "1"},
|
||||
{`'Hello' LIKE 'h_llo'`, "1"},
|
||||
{`'Hello' LIKE 'hello'`, "1"},
|
||||
{`'Привет' LIKE 'ПРИВЕТ'`, "1"},
|
||||
{`'100%' LIKE '100|%' ESCAPE '|'`, "1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.test, func(t *testing.T) {
|
||||
if got := exec(tt.test); got != tt.want {
|
||||
t.Errorf("exec(%q) = %q, want %q", tt.test, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_collation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"}
|
||||
|
||||
for stmt.Step() {
|
||||
got = append(got, stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`SELECT upper('hello', 'enUS')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT lower('hello', 'enUS')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' LIKE 'HELLO' ESCAPE '\\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('enUS', 'error')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('enUS', '')`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_like2regex(t *testing.T) {
|
||||
const prefix = `(?is)\A`
|
||||
const sufix = `\z`
|
||||
tests := []struct {
|
||||
pattern string
|
||||
escape rune
|
||||
want string
|
||||
}{
|
||||
{`a`, -1, `a`},
|
||||
{`a.`, -1, `a\.`},
|
||||
{`a%`, -1, `a.*`},
|
||||
{`a\`, -1, `a\\`},
|
||||
{`a_b`, -1, `a.b`},
|
||||
{`a|b`, '|', `ab`},
|
||||
{`a|_`, '|', `a_`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern, func(t *testing.T) {
|
||||
want := prefix + tt.want + sufix
|
||||
if got := like2regex(tt.pattern, tt.escape); got != want {
|
||||
t.Errorf("like2regex() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
175
func.go
Normal file
175
func.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// AnyCollationNeeded registers a fake collating function
|
||||
// for any unknown collating sequence.
|
||||
// The fake collating function works like BINARY.
|
||||
//
|
||||
// This can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c *Conn) AnyCollationNeeded() {
|
||||
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
|
||||
}
|
||||
|
||||
// CreateCollation defines a new collating sequence.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_collation.html
|
||||
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createCollation,
|
||||
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
|
||||
if err := c.error(r); err != nil {
|
||||
util.DelHandle(c.ctx, funcPtr)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFunction defines a new scalar SQL function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createFunction,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
// If fn returns an [io.Closer], it will be called to free resources.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
call := c.api.createAggregate
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
if _, ok := fn().(WindowFunction); ok {
|
||||
call = c.api.createWindow
|
||||
}
|
||||
r := c.call(call,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// AggregateFunction is the interface an aggregate function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/appfunc.html
|
||||
type AggregateFunction interface {
|
||||
// Step is invoked to add a row to the current window.
|
||||
// The function arguments, if any, corresponding to the row being added are passed to Step.
|
||||
Step(ctx Context, arg ...Value)
|
||||
|
||||
// Value is invoked to return the current (or final) value of the aggregate.
|
||||
Value(ctx Context)
|
||||
}
|
||||
|
||||
// WindowFunction is the interface an aggregate window function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/windowfunctions.html
|
||||
type WindowFunction interface {
|
||||
AggregateFunction
|
||||
|
||||
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
|
||||
// The function arguments, if any, are those passed to Step for the row being removed.
|
||||
Inverse(ctx Context, arg ...Value)
|
||||
}
|
||||
|
||||
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
||||
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
|
||||
}
|
||||
|
||||
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackHandle(db, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
var handle uint32
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{db, pCtx})
|
||||
if err := util.DelHandle(ctx, handle); err != nil {
|
||||
Context{db, pCtx}.ResultError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{db, pCtx})
|
||||
}
|
||||
|
||||
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackHandle(db *Conn, pCtx uint32) any {
|
||||
pApp := uint32(db.call(db.api.userData, uint64(pCtx)))
|
||||
return util.GetHandle(db.ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any {
|
||||
// On close, we're getting rid of the handle.
|
||||
// Don't allocate space to store it.
|
||||
var size uint64
|
||||
if close == nil {
|
||||
size = ptrlen
|
||||
}
|
||||
ptr := uint32(db.call(db.api.aggregateCtx, uint64(pCtx), size))
|
||||
|
||||
// Try loading the handle, if we already have one, or want a new one.
|
||||
if ptr != 0 || size != 0 {
|
||||
if handle := util.ReadUint32(db.mod, ptr); handle != 0 {
|
||||
fn := util.GetHandle(db.ctx, handle)
|
||||
if close != nil {
|
||||
*close = handle
|
||||
}
|
||||
if fn != nil {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new aggregate and store the handle.
|
||||
fn := callbackHandle(db, pCtx).(func() AggregateFunction)()
|
||||
if ptr != 0 {
|
||||
util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn))
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func callbackArgs(db *Conn, nArg, pArg uint32) []Value {
|
||||
args := make([]Value, nArg)
|
||||
for i := range args {
|
||||
args[i] = Value{
|
||||
sqlite: db.sqlite,
|
||||
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
154
func_test.go
Normal file
154
func_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateCollation() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateCollation("french", collate.New(language.French).Compare)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// cote
|
||||
// coté
|
||||
// côte
|
||||
// côté
|
||||
// cotée
|
||||
// coter
|
||||
}
|
||||
|
||||
func ExampleConn_CreateFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// COTE
|
||||
// COTÉ
|
||||
// CÔTE
|
||||
// CÔTÉ
|
||||
// COTÉE
|
||||
// COTER
|
||||
}
|
||||
|
||||
func ExampleContext_SetAuxData() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
ctx.SetAuxData(0, r)
|
||||
re = r
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// cote
|
||||
// côte
|
||||
// cotée
|
||||
}
|
||||
87
func_win_test.go
Normal file
87
func_win_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"unicode"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateWindowFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnInt(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// 1
|
||||
// 2
|
||||
// 2
|
||||
// 1
|
||||
// 0
|
||||
// 0
|
||||
}
|
||||
|
||||
type countASCII struct{ result int }
|
||||
|
||||
func newASCIICounter() sqlite3.AggregateFunction {
|
||||
return &countASCII{}
|
||||
}
|
||||
|
||||
func (f *countASCII) Value(ctx sqlite3.Context) {
|
||||
ctx.ResultInt(f.result)
|
||||
}
|
||||
|
||||
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result++
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result--
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
|
||||
if arg.Type() != sqlite3.TEXT {
|
||||
return false
|
||||
}
|
||||
for _, c := range arg.RawBlob() {
|
||||
if c > unicode.MaxASCII {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
11
go.mod
11
go.mod
@@ -1,13 +1,14 @@
|
||||
module github.com/ncruces/go-sqlite3
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v0.1.5
|
||||
github.com/ncruces/julianday v1.0.0
|
||||
github.com/psanford/httpreadat v0.1.0
|
||||
github.com/tetratelabs/wazero v1.2.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.8.0
|
||||
github.com/tetratelabs/wazero v1.5.0
|
||||
golang.org/x/sync v0.5.0
|
||||
golang.org/x/sys v0.14.0
|
||||
golang.org/x/text v0.14.0
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
|
||||
18
go.sum
18
go.sum
@@ -1,10 +1,12 @@
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
|
||||
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
|
||||
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
|
||||
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
|
||||
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
|
||||
5
go.work.sum
Normal file
5
go.work.sum
Normal file
@@ -0,0 +1,5 @@
|
||||
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
|
||||
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
22
gormlite/LICENSE
Normal file
22
gormlite/LICENSE
Normal file
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Nuno Cruces
|
||||
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
26
gormlite/README.md
Normal file
26
gormlite/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# GORM SQLite Driver
|
||||
|
||||
[](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/gormlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
|
||||
### Foreign-key constraint activation
|
||||
|
||||
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
|
||||
```go
|
||||
db, err := gorm.Open(gormlite.Open(
|
||||
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=foreign_keys(1)"),
|
||||
&gorm.Config{})
|
||||
```
|
||||
297
gormlite/ddlmod.go
Normal file
297
gormlite/ddlmod.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteSeparator = "`|\"|'|\t"
|
||||
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
|
||||
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
|
||||
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
|
||||
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
|
||||
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
|
||||
)
|
||||
|
||||
func getAllColumns(s string) []string {
|
||||
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
|
||||
columns := make([]string, 0, len(allMatches))
|
||||
for _, matches := range allMatches {
|
||||
if len(matches) > 1 {
|
||||
columns = append(columns, matches[1])
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
type ddl struct {
|
||||
head string
|
||||
fields []string
|
||||
columns []migrator.ColumnType
|
||||
}
|
||||
|
||||
func parseDDL(strs ...string) (*ddl, error) {
|
||||
var result ddl
|
||||
for _, str := range strs {
|
||||
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
|
||||
var (
|
||||
ddlBody = sections[2]
|
||||
ddlBodyRunes = []rune(ddlBody)
|
||||
bracketLevel int
|
||||
quote rune
|
||||
buf string
|
||||
)
|
||||
ddlBodyRunesLen := len(ddlBodyRunes)
|
||||
|
||||
result.head = sections[1]
|
||||
|
||||
for idx := 0; idx < ddlBodyRunesLen; idx++ {
|
||||
var (
|
||||
next rune = 0
|
||||
c = ddlBodyRunes[idx]
|
||||
)
|
||||
if idx+1 < ddlBodyRunesLen {
|
||||
next = ddlBodyRunes[idx+1]
|
||||
}
|
||||
|
||||
if sc := string(c); separatorRegexp.MatchString(sc) {
|
||||
if c == next {
|
||||
buf += sc // Skip escaped quote
|
||||
idx++
|
||||
} else if quote > 0 {
|
||||
quote = 0
|
||||
} else {
|
||||
quote = c
|
||||
}
|
||||
} else if quote == 0 {
|
||||
if c == '(' {
|
||||
bracketLevel++
|
||||
} else if c == ')' {
|
||||
bracketLevel--
|
||||
} else if bracketLevel == 0 {
|
||||
if c == ',' {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
buf = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bracketLevel < 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
buf += string(c)
|
||||
}
|
||||
|
||||
if bracketLevel != 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
if buf != "" {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
}
|
||||
|
||||
for _, f := range result.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
|
||||
for _, name := range getAllColumns(f) {
|
||||
for idx, column := range result.columns {
|
||||
if column.NameValue.String == name {
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
result.columns[idx] = column
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
|
||||
columnType := migrator.ColumnType{
|
||||
NameValue: sql.NullString{String: matches[1], Valid: true},
|
||||
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
UniqueValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
}
|
||||
|
||||
matchUpper := strings.ToUpper(matches[3])
|
||||
if strings.Contains(matchUpper, " NOT NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
|
||||
} else if strings.Contains(matchUpper, " NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " UNIQUE") {
|
||||
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " PRIMARY") {
|
||||
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
|
||||
if strings.ToLower(defaultMatches[1]) != "null" {
|
||||
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
|
||||
}
|
||||
}
|
||||
|
||||
// data type length
|
||||
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
|
||||
if len(matches) == 1 && len(matches[0]) == 2 {
|
||||
size, _ := strconv.Atoi(matches[0][1])
|
||||
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
|
||||
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
|
||||
}
|
||||
|
||||
result.columns = append(result.columns, columnType)
|
||||
}
|
||||
}
|
||||
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
|
||||
for _, column := range getAllColumns(matches[1]) {
|
||||
for idx, c := range result.columns {
|
||||
if c.NameValue.String == column {
|
||||
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
|
||||
result.columns[idx] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("invalid DDL")
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (d *ddl) clone() *ddl {
|
||||
copied := new(ddl)
|
||||
*copied = *d
|
||||
|
||||
copied.fields = make([]string, len(d.fields))
|
||||
copy(copied.fields, d.fields)
|
||||
copied.columns = make([]migrator.ColumnType, len(d.columns))
|
||||
copy(copied.columns, d.columns)
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) renameTable(dst, src string) error {
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
|
||||
if replaced == d.head {
|
||||
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
|
||||
}
|
||||
|
||||
d.head = replaced
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
}
|
||||
|
||||
func (d *ddl) removeConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//lint:ignore U1000 ignore unused code.
|
||||
func (d *ddl) hasConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for _, f := range d.fields {
|
||||
if reg.MatchString(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
for _, f := range d.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
|
||||
strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
|
||||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
|
||||
continue
|
||||
}
|
||||
|
||||
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
|
||||
match := reg.FindStringSubmatch(f)
|
||||
|
||||
if match != nil {
|
||||
res = append(res, "`"+match[1]+"`")
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (d *ddl) alterColumn(name, sql string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *ddl) removeColumn(name string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
352
gormlite/ddlmod_test.go
Normal file
352
gormlite/ddlmod_test.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestParseDDL(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{"with_fk", []string{
|
||||
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
|
||||
}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
}},
|
||||
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"no brackets", []string{"create table test"}, 0, nil},
|
||||
{"with_special_characters", []string{
|
||||
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
|
||||
}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"table_name_with_dash",
|
||||
[]string{
|
||||
"CREATE TABLE `test-a` (`id` int NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"non-unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
|
||||
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
tests.AssertEqual(t, p.sql[0], ddl.compile())
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
testColumns := []migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
|
||||
},
|
||||
{
|
||||
NameValue: sql.NullString{String: "dark_mode", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{String: "true", Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{
|
||||
"with_newline",
|
||||
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_newline_2",
|
||||
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_missing_space",
|
||||
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_many_spaces",
|
||||
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
}
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_error(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql string
|
||||
}{
|
||||
{"invalid_cmd", "CREATE TABLE"},
|
||||
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
|
||||
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
_, err := parseDDL(p.sql)
|
||||
if err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
sql string
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "add_new",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
},
|
||||
{
|
||||
name: "update",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
},
|
||||
{
|
||||
name: "add_check",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
{
|
||||
name: "update_check",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
testDDL.addConstraint(p.cName, p.sql)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
success bool
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "fk",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "check",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "none",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "nothing",
|
||||
success: false,
|
||||
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
success := testDDL.removeConstraint(p.cName)
|
||||
|
||||
tests.AssertEqual(t, p.success, success)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumns(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
ddl string
|
||||
columns []string
|
||||
}{
|
||||
{
|
||||
name: "with_fk",
|
||||
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
columns: []string{"`id`", "`text`", "`user_id`"},
|
||||
},
|
||||
{
|
||||
name: "with_check",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
|
||||
},
|
||||
{
|
||||
name: "with_escaped_quote",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_generated_column",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_new_line",
|
||||
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
|
||||
"id" integer PRIMARY KEY AUTOINCREMENT,
|
||||
"created_at" datetime NOT NULL,
|
||||
"updated_at" datetime NOT NULL,
|
||||
"created_by" integer NOT NULL DEFAULT 0,
|
||||
"updated_by" integer NOT NULL DEFAULT 0,
|
||||
"role_id" integer NOT NULL,
|
||||
"menu_id" bigint NOT NULL
|
||||
)`,
|
||||
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL, err := parseDDL(p.ddl)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
cols := testDDL.getColumns()
|
||||
|
||||
tests.AssertEqual(t, p.columns, cols)
|
||||
})
|
||||
}
|
||||
}
|
||||
11
gormlite/download.sh
Executable file
11
gormlite/download.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"
|
||||
21
gormlite/error_translator.go
Normal file
21
gormlite/error_translator.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (_Dialector) Translate(err error) error {
|
||||
switch {
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
|
||||
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
|
||||
return gorm.ErrDuplicatedKey
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
|
||||
return gorm.ErrForeignKeyViolated
|
||||
}
|
||||
return err
|
||||
}
|
||||
16
gormlite/go.mod
Normal file
16
gormlite/go.mod
Normal file
@@ -0,0 +1,16 @@
|
||||
module github.com/ncruces/go-sqlite3/gormlite
|
||||
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/go-sqlite3 v0.9.1
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/ncruces/julianday v0.1.5 // indirect
|
||||
github.com/tetratelabs/wazero v1.5.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
)
|
||||
16
gormlite/go.sum
Normal file
16
gormlite/go.sum
Normal file
@@ -0,0 +1,16 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
406
gormlite/migrator.go
Normal file
406
gormlite/migrator.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type _Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m *_Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
var enabled int
|
||||
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
|
||||
if enabled == 1 {
|
||||
m.DB.Exec("PRAGMA foreign_keys = OFF")
|
||||
defer m.DB.Exec("PRAGMA foreign_keys = ON")
|
||||
}
|
||||
|
||||
return fc()
|
||||
}
|
||||
|
||||
func (m _Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m _Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m _Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m _Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m _Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
|
||||
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
|
||||
}
|
||||
|
||||
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
sqls []string
|
||||
sqlDDL *ddl
|
||||
)
|
||||
|
||||
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlDDL, err = parseDDL(sqls...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = rows.Close()
|
||||
}()
|
||||
|
||||
var rawColumnTypes []*sql.ColumnType
|
||||
rawColumnTypes, err = rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
columnType := migrator.ColumnType{SQLColumnType: c}
|
||||
for _, column := range sqlDDL.columns {
|
||||
if column.NameValue.String == c.Name() {
|
||||
column.SQLColumnType = c
|
||||
columnType = column
|
||||
break
|
||||
}
|
||||
}
|
||||
columnTypes = append(columnTypes, columnType)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m _Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
ddl.removeColumn(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m _Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
constraintValues []interface{}
|
||||
)
|
||||
|
||||
if constraint != nil {
|
||||
constraintName = constraint.Name
|
||||
constraintSql, constraintValues = buildConstraint(constraint)
|
||||
} else if chk != nil {
|
||||
constraintName = chk.Name
|
||||
constraintSql = "CONSTRAINT ? CHECK (?)"
|
||||
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
} else {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
ddl.addConstraint(constraintName, constraintSql)
|
||||
return ddl, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m _Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
ddl.removeConstraint(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m _Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m _Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m _Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
createIndexSQL += " ON ??"
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m _Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
if sql != "" {
|
||||
if err := m.DropIndex(value, oldName); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
|
||||
}
|
||||
return fmt.Errorf("failed to find index with name %v", oldName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m _Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (m _Migrator) getRawDDL(table string) (string, error) {
|
||||
var createSQL string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
|
||||
|
||||
if m.DB.Error != nil {
|
||||
return "", m.DB.Error
|
||||
}
|
||||
return createSQL, nil
|
||||
}
|
||||
|
||||
func (m _Migrator) recreateTable(
|
||||
value interface{}, tablePtr *string,
|
||||
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
|
||||
) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
table := stmt.Table
|
||||
if tablePtr != nil {
|
||||
table = *tablePtr
|
||||
}
|
||||
|
||||
rawDDL, err := m.getRawDDL(table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
originDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createDDL == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
newTableName := table + "__temp"
|
||||
if err := createDDL.renameTable(newTableName, table); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
columns := createDDL.getColumns()
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return m.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := []string{
|
||||
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
|
||||
fmt.Sprintf("DROP TABLE `%v`", table),
|
||||
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
|
||||
}
|
||||
for _, query := range queries {
|
||||
if err := tx.Exec(query).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
257
gormlite/sqlite.go
Normal file
257
gormlite/sqlite.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// Package gormlite provides a GORM driver for SQLite.
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
// Open opens a GORM dialector from a data source name.
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &_Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
// Open opens a GORM dialector from a database handle.
|
||||
func OpenDB(db *sql.DB) gorm.Dialector {
|
||||
return &_Dialector{Conn: db}
|
||||
}
|
||||
|
||||
type _Dialector struct {
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
}
|
||||
|
||||
func (dialector _Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
conn, err := driver.Open(dialector.DSN, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.ConnPool = conn
|
||||
}
|
||||
|
||||
var version string
|
||||
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
|
||||
return err
|
||||
}
|
||||
// https://www.sqlite.org/releaselog/3_35_0.html
|
||||
if compareVersion(version, "3.35.0") >= 0 {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
} else {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"INSERT": func(c clause.Clause, builder clause.Builder) {
|
||||
if insert, ok := c.Expression.(clause.Insert); ok {
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
stmt.WriteString("INSERT ")
|
||||
if insert.Modifier != "" {
|
||||
stmt.WriteString(insert.Modifier)
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
stmt.WriteString("INTO ")
|
||||
if insert.Table.Name == "" {
|
||||
stmt.WriteQuoted(stmt.Table)
|
||||
} else {
|
||||
stmt.WriteQuoted(insert.Table)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
},
|
||||
"LIMIT": func(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
var lmt = -1
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
lmt = *limit.Limit
|
||||
}
|
||||
if lmt >= 0 || limit.Offset > 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(lmt))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
builder.WriteString(" OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
}
|
||||
}
|
||||
},
|
||||
"FOR": func(c clause.Clause, builder clause.Builder) {
|
||||
if _, ok := c.Expression.(clause.Locking); ok {
|
||||
// SQLite3 does not support row-level locking.
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
if field.AutoIncrement {
|
||||
return clause.Expr{SQL: "NULL"}
|
||||
}
|
||||
|
||||
// doesn't work, will raise error
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return _Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
var (
|
||||
underQuoted, selfQuoted bool
|
||||
continuousBacktick int8
|
||||
shiftDelimiter int8
|
||||
)
|
||||
|
||||
for _, v := range []byte(str) {
|
||||
switch v {
|
||||
case '`':
|
||||
continuousBacktick++
|
||||
if continuousBacktick == 2 {
|
||||
writer.WriteString("``")
|
||||
continuousBacktick = 0
|
||||
}
|
||||
case '.':
|
||||
if continuousBacktick > 0 || !selfQuoted {
|
||||
shiftDelimiter = 0
|
||||
underQuoted = false
|
||||
continuousBacktick = 0
|
||||
writer.WriteString("`")
|
||||
}
|
||||
writer.WriteByte(v)
|
||||
continue
|
||||
default:
|
||||
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
||||
writer.WriteString("`")
|
||||
underQuoted = true
|
||||
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
||||
continuousBacktick -= 1
|
||||
}
|
||||
}
|
||||
|
||||
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
|
||||
writer.WriteByte(v)
|
||||
}
|
||||
shiftDelimiter++
|
||||
}
|
||||
|
||||
if continuousBacktick > 0 && !selfQuoted {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
writer.WriteString("`")
|
||||
}
|
||||
|
||||
func (dialector _Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector _Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement {
|
||||
// doesn't check `PrimaryKey`, to keep backward compatibility
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
return "integer"
|
||||
}
|
||||
case schema.Float:
|
||||
return "real"
|
||||
case schema.String:
|
||||
return "text"
|
||||
case schema.Time:
|
||||
// Distinguish between schema.Time and tag time
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
return val
|
||||
} else {
|
||||
return "datetime"
|
||||
}
|
||||
case schema.Bytes:
|
||||
return "blob"
|
||||
}
|
||||
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareVersion(version1, version2 string) int {
|
||||
n, m := len(version1), len(version2)
|
||||
i, j := 0, 0
|
||||
for i < n || j < m {
|
||||
x := 0
|
||||
for ; i < n && version1[i] != '.'; i++ {
|
||||
x = x*10 + int(version1[i]-'0')
|
||||
}
|
||||
i++
|
||||
y := 0
|
||||
for ; j < m && version2[j] != '.'; j++ {
|
||||
y = y*10 + int(version2[j]-'0')
|
||||
}
|
||||
j++
|
||||
if x > y {
|
||||
return 1
|
||||
}
|
||||
if x < y {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
96
gormlite/sqlite_test.go
Normal file
96
gormlite/sqlite_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDialector(t *testing.T) {
|
||||
// This is the DSN of the in-memory SQLite database for these tests.
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
// Custom connection with a custom function called "my_custom_function".
|
||||
db, err := driver.Open(InMemoryDSN, func(conn *sqlite3.Conn) error {
|
||||
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText("my-result")
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector gorm.Dialector
|
||||
openSuccess bool
|
||||
query string
|
||||
querySuccess bool
|
||||
}{
|
||||
{
|
||||
description: "Default driver",
|
||||
dialector: Open(InMemoryDSN),
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom function",
|
||||
dialector: Open(InMemoryDSN),
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: false,
|
||||
},
|
||||
{
|
||||
description: "Custom connection",
|
||||
dialector: OpenDB(db),
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom connection, custom function",
|
||||
dialector: OpenDB(db),
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: true,
|
||||
},
|
||||
}
|
||||
for rowIndex, row := range rows {
|
||||
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
|
||||
db, err := gorm.Open(row.dialector, &gorm.Config{})
|
||||
if !row.openSuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected Open to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected Open to succeed; got error: %v", err)
|
||||
}
|
||||
if db == nil {
|
||||
t.Errorf("Expected db to be non-nil.")
|
||||
}
|
||||
if row.query != "" {
|
||||
err = db.Exec(row.query).Error
|
||||
if !row.querySuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected query to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected query to succeed; got error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
24
gormlite/test.sh
Executable file
24
gormlite/test.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
rm -rf gorm/ tests/
|
||||
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
|
||||
mv gorm/tests tests
|
||||
rm -rf gorm/
|
||||
|
||||
patch -p1 -N < tests.patch
|
||||
|
||||
cd tests
|
||||
go mod edit \
|
||||
-require github.com/ncruces/go-sqlite3/gormlite@v0.0.0 \
|
||||
-replace github.com/ncruces/go-sqlite3/gormlite=../ \
|
||||
-replace github.com/ncruces/go-sqlite3=../../ \
|
||||
-droprequire gorm.io/driver/sqlite \
|
||||
-dropreplace gorm.io/gorm
|
||||
go mod tidy && go work use . && go test
|
||||
|
||||
cd ..
|
||||
rm -rf tests/
|
||||
go work use -r .
|
||||
31
gormlite/tests.patch
Normal file
31
gormlite/tests.patch
Normal file
@@ -0,0 +1,31 @@
|
||||
diff --git a/tests/.gitignore b/tests/.gitignore
|
||||
--- a/tests/.gitignore
|
||||
+++ b/tests/.gitignore
|
||||
@@ -1 +1 @@
|
||||
-go.sum
|
||||
+*
|
||||
diff --git a/tests/tests_test.go b/tests/tests_test.go
|
||||
--- a/tests/tests_test.go
|
||||
+++ b/tests/tests_test.go
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
+ _ "github.com/ncruces/go-sqlite3/embed"
|
||||
+ sqlite "github.com/ncruces/go-sqlite3/gormlite"
|
||||
+
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
- "gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -89,7 +91,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
||||
default:
|
||||
log.Println("testing sqlite3...")
|
||||
- db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
|
||||
+ db, err = gorm.Open(sqlite.Open("file:"+filepath.Join(os.TempDir(), "gorm.db")+"?_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"), cfg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -72,6 +72,7 @@ const (
|
||||
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
|
||||
IOERR_DATA = IOERR | (32 << 8)
|
||||
IOERR_CORRUPTFS = IOERR | (33 << 8)
|
||||
IOERR_IN_PAGE = IOERR | (34 << 8)
|
||||
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
|
||||
LOCKED_VTAB = LOCKED | (2 << 8)
|
||||
BUSY_RECOVERY = BUSY | (1 << 8)
|
||||
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
OffsetErr = ErrorString("sqlite3: invalid offset")
|
||||
TailErr = ErrorString("sqlite3: multiple statements")
|
||||
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
|
||||
ValueErr = ErrorString("sqlite3: unsupported value")
|
||||
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,32 @@ import (
|
||||
type i32 interface{ ~int32 | ~uint32 }
|
||||
type i64 interface{ ~int64 | ~uint64 }
|
||||
|
||||
type funcVI[T0 i32] func(context.Context, api.Module, T0)
|
||||
|
||||
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]))
|
||||
}
|
||||
|
||||
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVI[T0](fn),
|
||||
[]api.ValueType{api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
|
||||
|
||||
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
|
||||
}
|
||||
|
||||
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
|
||||
|
||||
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
|
||||
75
internal/util/handle.go
Normal file
75
internal/util/handle.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
)
|
||||
|
||||
type handleKey struct{}
|
||||
type handleState struct {
|
||||
handles []any
|
||||
empty int
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context) context.Context {
|
||||
state := new(handleState)
|
||||
ctx = experimental.WithCloseNotifier(ctx, state)
|
||||
ctx = context.WithValue(ctx, handleKey{}, state)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
|
||||
for _, h := range s.handles {
|
||||
if c, ok := h.(io.Closer); ok {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
s.handles = nil
|
||||
s.empty = 0
|
||||
}
|
||||
|
||||
func GetHandle(ctx context.Context, id uint32) any {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
return s.handles[^id]
|
||||
}
|
||||
|
||||
func DelHandle(ctx context.Context, id uint32) error {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
a := s.handles[^id]
|
||||
s.handles[^id] = nil
|
||||
s.empty++
|
||||
if c, ok := a.(io.Closer); ok {
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddHandle(ctx context.Context, a any) (id uint32) {
|
||||
if a == nil {
|
||||
panic(NilErr)
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
|
||||
// Find an empty slot.
|
||||
if s.empty > cap(s.handles)-len(s.handles) {
|
||||
for id, h := range s.handles {
|
||||
if h == nil {
|
||||
s.empty--
|
||||
s.handles[id] = a
|
||||
return ^uint32(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new slot.
|
||||
s.handles = append(s.handles, a)
|
||||
return -uint32(len(s.handles))
|
||||
}
|
||||
@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
|
||||
if size > math.MaxUint32 {
|
||||
panic(RangeErr)
|
||||
}
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
buf, ok := mod.Memory().Read(ptr, uint32(size))
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
|
||||
46
json.go
Normal file
46
json.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// JSON returns a value that can be used as an argument to
|
||||
// [database/sql.DB.Exec], [database/sql.Row.Scan] and similar methods to
|
||||
// store value as JSON, or decode JSON into value.
|
||||
func JSON(value any) any {
|
||||
return jsonValue{value}
|
||||
}
|
||||
|
||||
type jsonValue struct{ any }
|
||||
|
||||
func (j jsonValue) JSON() any { return j.any }
|
||||
|
||||
func (j jsonValue) Scan(value any) error {
|
||||
var buf []byte
|
||||
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
buf = v
|
||||
case string:
|
||||
buf = unsafe.Slice(unsafe.StringData(v), len(v))
|
||||
case int64:
|
||||
buf = strconv.AppendInt(nil, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
buf = append(buf, '"')
|
||||
buf = v.AppendFormat(buf, time.RFC3339Nano)
|
||||
buf = append(buf, '"')
|
||||
case nil:
|
||||
buf = append(buf, "null"...)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
return json.Unmarshal(buf, j.any)
|
||||
}
|
||||
14
pointer.go
Normal file
14
pointer.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package sqlite3
|
||||
|
||||
// Pointer returns a pointer to a value
|
||||
// that can be used as an argument to
|
||||
// [database/sql.DB.Exec] and similar methods.
|
||||
//
|
||||
// https://www.sqlite.org/bindptr.html
|
||||
func Pointer[T any](val T) any {
|
||||
return pointer[T]{val}
|
||||
}
|
||||
|
||||
type pointer[T any] struct{ val T }
|
||||
|
||||
func (p pointer[T]) Pointer() any { return p.val }
|
||||
112
quote.go
Normal file
112
quote.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Quote escapes and quotes a value
|
||||
// making it safe to embed in SQL text.
|
||||
func Quote(value any) string {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "NULL"
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
} else {
|
||||
return "0"
|
||||
}
|
||||
|
||||
case int:
|
||||
return strconv.Itoa(v)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case float64:
|
||||
switch {
|
||||
case math.IsNaN(v):
|
||||
return "NULL"
|
||||
case math.IsInf(v, 1):
|
||||
return "9.0e999"
|
||||
case math.IsInf(v, -1):
|
||||
return "-9.0e999"
|
||||
}
|
||||
return strconv.FormatFloat(v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
return "'" + v.Format(time.RFC3339Nano) + "'"
|
||||
|
||||
case string:
|
||||
if strings.IndexByte(v, 0) >= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
|
||||
buf[0] = '\''
|
||||
i := 1
|
||||
for _, b := range []byte(v) {
|
||||
if b == '\'' {
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
|
||||
case []byte:
|
||||
buf := make([]byte, 3+2*len(v))
|
||||
buf[0] = 'x'
|
||||
buf[1] = '\''
|
||||
i := 2
|
||||
for _, b := range v {
|
||||
const hex = "0123456789ABCDEF"
|
||||
buf[i+0] = hex[b/16]
|
||||
buf[i+1] = hex[b%16]
|
||||
i += 2
|
||||
}
|
||||
buf[i] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
|
||||
case ZeroBlob:
|
||||
if v > ZeroBlob(1e9-3)/2 {
|
||||
break
|
||||
}
|
||||
|
||||
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
|
||||
buf[0] = 'x'
|
||||
buf[1] = '\''
|
||||
buf[len(buf)-1] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
}
|
||||
|
||||
panic(util.ValueErr)
|
||||
}
|
||||
|
||||
// QuoteIdentifier escapes and quotes an identifier
|
||||
// making it safe to embed in SQL text.
|
||||
func QuoteIdentifier(id string) string {
|
||||
if strings.IndexByte(id, 0) >= 0 {
|
||||
panic(util.ValueErr)
|
||||
}
|
||||
|
||||
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
|
||||
buf[0] = '"'
|
||||
i := 1
|
||||
for _, b := range []byte(id) {
|
||||
if b == '"' {
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = '"'
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -23,72 +22,72 @@ import (
|
||||
var (
|
||||
Binary []byte // WASM binary to load.
|
||||
Path string // Path to load the binary from.
|
||||
|
||||
RuntimeConfig wazero.RuntimeConfig
|
||||
)
|
||||
|
||||
var sqlite3 struct {
|
||||
var instance struct {
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func instantiateModule() (*module, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
sqlite3.once.Do(compileModule)
|
||||
if sqlite3.err != nil {
|
||||
return nil, sqlite3.err
|
||||
func compileSQLite() {
|
||||
if RuntimeConfig == nil {
|
||||
RuntimeConfig = wazero.NewRuntimeConfig()
|
||||
}
|
||||
|
||||
cfg := wazero.NewModuleConfig()
|
||||
|
||||
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newModule(mod)
|
||||
}
|
||||
|
||||
func compileModule() {
|
||||
ctx := context.Background()
|
||||
sqlite3.runtime = wazero.NewRuntime(ctx)
|
||||
instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
|
||||
|
||||
env := vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
|
||||
_, sqlite3.err = env.Instantiate(ctx)
|
||||
if sqlite3.err != nil {
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportCallbacks(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
bin, sqlite3.err = os.ReadFile(Path)
|
||||
if sqlite3.err != nil {
|
||||
bin, instance.err = os.ReadFile(Path)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if bin == nil {
|
||||
sqlite3.err = util.BinaryErr
|
||||
instance.err = util.BinaryErr
|
||||
return
|
||||
}
|
||||
|
||||
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
|
||||
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
type module struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
vfs io.Closer
|
||||
api sqliteAPI
|
||||
arg [8]uint64
|
||||
type sqlite struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
api sqliteAPI
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
func newModule(mod api.Module) (m *module, err error) {
|
||||
m = new(module)
|
||||
m.mod = mod
|
||||
m.ctx, m.vfs = vfs.NewContext(context.Background())
|
||||
func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
return nil, instance.err
|
||||
}
|
||||
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx = util.NewContext(context.Background())
|
||||
|
||||
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
|
||||
instance.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getFun := func(name string) api.Function {
|
||||
f := mod.ExportedFunction(name)
|
||||
f := sqlt.mod.ExportedFunction(name)
|
||||
if f == nil {
|
||||
err = util.NoFuncErr + util.ErrorString(name)
|
||||
return nil
|
||||
@@ -97,15 +96,15 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
}
|
||||
|
||||
getVal := func(name string) uint32 {
|
||||
g := mod.ExportedGlobal(name)
|
||||
g := sqlt.mod.ExportedGlobal(name)
|
||||
if g == nil {
|
||||
err = util.NoGlobalErr + util.ErrorString(name)
|
||||
return 0
|
||||
}
|
||||
return util.ReadUint32(mod, uint32(g.Get()))
|
||||
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
|
||||
}
|
||||
|
||||
m.api = sqliteAPI{
|
||||
sqlt.api = sqliteAPI{
|
||||
free: getFun("free"),
|
||||
malloc: getFun("malloc"),
|
||||
destructor: getVal("malloc_destructor"),
|
||||
@@ -121,6 +120,8 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
reset: getFun("sqlite3_reset"),
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
interrupt: getFun("sqlite3_interrupt"),
|
||||
progressHandler: getFun("sqlite3_progress_handler_go"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
@@ -131,6 +132,7 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
bindPointer: getFun("sqlite3_bind_pointer_go"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
@@ -153,20 +155,46 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
anyCollation: getFun("sqlite3_anycollseq_init"),
|
||||
createCollation: getFun("sqlite3_create_collation_go"),
|
||||
createFunction: getFun("sqlite3_create_function_go"),
|
||||
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
|
||||
createWindow: getFun("sqlite3_create_window_function_go"),
|
||||
aggregateCtx: getFun("sqlite3_aggregate_context"),
|
||||
userData: getFun("sqlite3_user_data"),
|
||||
setAuxData: getFun("sqlite3_set_auxdata_go"),
|
||||
getAuxData: getFun("sqlite3_get_auxdata"),
|
||||
valueType: getFun("sqlite3_value_type"),
|
||||
valueInteger: getFun("sqlite3_value_int64"),
|
||||
valueFloat: getFun("sqlite3_value_double"),
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
valuePointer: getFun("sqlite3_value_pointer_go"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultPointer: getFun("sqlite3_result_pointer_go"),
|
||||
resultValue: getFun("sqlite3_result_value"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
||||
resultErrorBig: getFun("sqlite3_result_error_toobig"),
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
return sqlt, nil
|
||||
}
|
||||
|
||||
func (m *module) close() error {
|
||||
err := m.mod.Close(m.ctx)
|
||||
m.vfs.Close()
|
||||
return err
|
||||
func (sqlt *sqlite) close() error {
|
||||
return sqlt.mod.Close(sqlt.ctx)
|
||||
}
|
||||
|
||||
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
}
|
||||
@@ -177,17 +205,19 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
|
||||
if r := m.call(m.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
|
||||
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
if handle != 0 {
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
if sql != nil {
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,60 +228,58 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
return &err
|
||||
}
|
||||
|
||||
func (m *module) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(m.arg[:], params)
|
||||
err := fn.CallWithStack(m.ctx, m.arg[:])
|
||||
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(sqlt.stack[:], params)
|
||||
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
|
||||
if err != nil {
|
||||
// The module closed or panicked; release resources.
|
||||
m.vfs.Close()
|
||||
panic(err)
|
||||
}
|
||||
return m.arg[0]
|
||||
return sqlt.stack[0]
|
||||
}
|
||||
|
||||
func (m *module) free(ptr uint32) {
|
||||
func (sqlt *sqlite) free(ptr uint32) {
|
||||
if ptr == 0 {
|
||||
return
|
||||
}
|
||||
m.call(m.api.free, uint64(ptr))
|
||||
sqlt.call(sqlt.api.free, uint64(ptr))
|
||||
}
|
||||
|
||||
func (m *module) new(size uint64) uint32 {
|
||||
func (sqlt *sqlite) new(size uint64) uint32 {
|
||||
if size > _MAX_ALLOCATION_SIZE {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
ptr := uint32(m.call(m.api.malloc, size))
|
||||
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
||||
if (*[0]byte)(b) == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := m.new(uint64(len(b)))
|
||||
util.WriteBytes(m.mod, ptr, b)
|
||||
ptr := sqlt.new(uint64(len(b)))
|
||||
util.WriteBytes(sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newString(s string) uint32 {
|
||||
ptr := m.new(uint64(len(s) + 1))
|
||||
util.WriteString(m.mod, ptr, s)
|
||||
func (sqlt *sqlite) newString(s string) uint32 {
|
||||
ptr := sqlt.new(uint64(len(s) + 1))
|
||||
util.WriteString(sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newArena(size uint64) arena {
|
||||
func (sqlt *sqlite) newArena(size uint64) arena {
|
||||
return arena{
|
||||
m: m,
|
||||
base: m.new(size),
|
||||
sqlt: sqlt,
|
||||
size: uint32(size),
|
||||
base: sqlt.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
m *module
|
||||
sqlt *sqlite
|
||||
ptrs []uint32
|
||||
base uint32
|
||||
next uint32
|
||||
@@ -259,17 +287,17 @@ type arena struct {
|
||||
}
|
||||
|
||||
func (a *arena) free() {
|
||||
if a.m == nil {
|
||||
if a.sqlt == nil {
|
||||
return
|
||||
}
|
||||
a.reset()
|
||||
a.m.free(a.base)
|
||||
a.m = nil
|
||||
a.sqlt.free(a.base)
|
||||
a.sqlt = nil
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.m.free(ptr)
|
||||
a.sqlt.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
@@ -281,7 +309,7 @@ func (a *arena) new(size uint64) uint32 {
|
||||
a.next += uint32(size)
|
||||
return ptr
|
||||
}
|
||||
ptr := a.m.new(size)
|
||||
ptr := a.sqlt.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
@@ -291,13 +319,13 @@ func (a *arena) bytes(b []byte) uint32 {
|
||||
return 0
|
||||
}
|
||||
ptr := a.new(uint64(len(b)))
|
||||
util.WriteBytes(a.m.mod, ptr, b)
|
||||
util.WriteBytes(a.sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint64(len(s) + 1))
|
||||
util.WriteString(a.m.mod, ptr, s)
|
||||
util.WriteString(a.sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
@@ -316,16 +344,19 @@ type sqliteAPI struct {
|
||||
reset api.Function
|
||||
step api.Function
|
||||
exec api.Function
|
||||
interrupt api.Function
|
||||
progressHandler api.Function
|
||||
clearBindings api.Function
|
||||
bindNull api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindNull api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
bindPointer api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
@@ -348,5 +379,45 @@ type sqliteAPI struct {
|
||||
changes api.Function
|
||||
lastRowid api.Function
|
||||
autocommit api.Function
|
||||
anyCollation api.Function
|
||||
createCollation api.Function
|
||||
createFunction api.Function
|
||||
createAggregate api.Function
|
||||
createWindow api.Function
|
||||
aggregateCtx api.Function
|
||||
userData api.Function
|
||||
setAuxData api.Function
|
||||
getAuxData api.Function
|
||||
valueType api.Function
|
||||
valueInteger api.Function
|
||||
valueFloat api.Function
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
valuePointer api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultPointer api.Function
|
||||
resultValue api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
resultErrorMem api.Function
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
}
|
||||
|
||||
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncII(env, "go_progress", callbackProgress)
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
1
sqlite3/.gitignore
vendored
1
sqlite3/.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
ext/
|
||||
sqlite3.c
|
||||
sqlite3.h
|
||||
sqlite3ext.h
|
||||
@@ -1,20 +0,0 @@
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -60425,7 +60425,7 @@
|
||||
int rc = SQLITE_OK; /* Return code */
|
||||
int tempFile = 0; /* True for temp files (incl. in-memory files) */
|
||||
int memDb = 0; /* True if this is an in-memory file */
|
||||
-#ifndef SQLITE_OMIT_DESERIALIZE
|
||||
+#if 1
|
||||
int memJM = 0; /* Memory journal mode */
|
||||
#else
|
||||
# define memJM 0
|
||||
@@ -60628,7 +60628,7 @@
|
||||
int fout = 0; /* VFS flags returned by xOpen() */
|
||||
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
|
||||
assert( !memDb );
|
||||
-#ifndef SQLITE_OMIT_DESERIALIZE
|
||||
+#if 1
|
||||
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
|
||||
#endif
|
||||
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;
|
||||
@@ -3,32 +3,33 @@ set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3440000.zip"
|
||||
unzip -d . sqlite-amalgamation-*.zip
|
||||
mv sqlite-amalgamation-*/sqlite3* .
|
||||
rm -rf sqlite-amalgamation-*
|
||||
|
||||
patch < vfs_find.patch
|
||||
patch < deserialize.patch
|
||||
cat *.patch | patch --posix
|
||||
|
||||
mkdir -p ext/
|
||||
cd ext/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/anycollseq.c"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/mptest/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/multiwrite01.test"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/speedtest1/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/test/speedtest1.c"
|
||||
cd ~-
|
||||
1
sqlite3/ext/.gitignore
vendored
1
sqlite3/ext/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
*.c
|
||||
55
sqlite3/func.c
Normal file
55
sqlite3/func.c
Normal file
@@ -0,0 +1,55 @@
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_compare(void *, int, const void *, int, const void *);
|
||||
void go_func(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_step(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_final(sqlite3_context *);
|
||||
void go_value(sqlite3_context *);
|
||||
void go_inverse(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_destroy(void *);
|
||||
|
||||
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
|
||||
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
|
||||
go_func, /*step=*/NULL, /*final=*/NULL,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
|
||||
int nArg, int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, /*value=*/NULL,
|
||||
/*inverse=*/NULL, go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, go_value,
|
||||
go_inverse, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
|
||||
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
|
||||
}
|
||||
|
||||
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"
|
||||
|
||||
int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
|
||||
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
|
||||
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
|
||||
}
|
||||
|
||||
void *sqlite3_value_pointer_go(sqlite3_value *val) {
|
||||
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
|
||||
}
|
||||
34
sqlite3/isoweek.patch
Normal file
34
sqlite3/isoweek.patch
Normal file
@@ -0,0 +1,34 @@
|
||||
# ISO week date specifiers.
|
||||
# https://sqlite.org/forum/forumpost/73d99e4497e8e6a7
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -1373,6 +1373,29 @@ static void strftimeFunc(
|
||||
sqlite3_str_appendchar(&sRes, 1, c);
|
||||
break;
|
||||
}
|
||||
+ case 'V': /* Fall thru */
|
||||
+ case 'G': {
|
||||
+ DateTime y = x;
|
||||
+ computeJD(&y);
|
||||
+ y.validYMD = 0;
|
||||
+ /* Adjust date to Thursday this week:
|
||||
+ The number in parentheses is 0 for Monday, 3 for Thursday */
|
||||
+ y.iJD += (3 - (((y.iJD+43200000)/86400000) % 7))*86400000;
|
||||
+ computeYMD(&y);
|
||||
+ if( cf=='G' ){
|
||||
+ sqlite3_str_appendf(&sRes,"%04d",y.Y);
|
||||
+ }else{
|
||||
+ int nDay; /* Number of days since 1st day of year */
|
||||
+ i64 tJD = y.iJD;
|
||||
+ y.validJD = 0;
|
||||
+ y.M = 1;
|
||||
+ y.D = 1;
|
||||
+ computeJD(&y);
|
||||
+ nDay = (int)((tJD-y.iJD+43200000)/86400000);
|
||||
+ sqlite3_str_appendf(&sRes,"%02d",nDay/7+1);
|
||||
+ }
|
||||
+ break;
|
||||
+ }
|
||||
case 'Y': {
|
||||
sqlite3_str_appendf(&sRes,"%04d",x.Y);
|
||||
break;
|
||||
14
sqlite3/locking_mode.patch
Normal file
14
sqlite3/locking_mode.patch
Normal file
@@ -0,0 +1,14 @@
|
||||
# Use exclusive locking mode for WAL databases with v1 VFSes.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -63210,7 +63210,9 @@
|
||||
SQLITE_PRIVATE int sqlite3PagerWalSupported(Pager *pPager){
|
||||
const sqlite3_io_methods *pMethods = pPager->fd->pMethods;
|
||||
if( pPager->noLock ) return 0;
|
||||
- return pPager->exclusiveMode || (pMethods->iVersion>=2 && pMethods->xShmMap);
|
||||
+ if( pMethods->iVersion>=2 && pMethods->xShmMap ) return 1;
|
||||
+ pPager->exclusiveMode = 1;
|
||||
+ return 1;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -1,19 +1,17 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
// Configuration
|
||||
#include "sqlite_cfg.h"
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
// VFS
|
||||
#include "vfs.c"
|
||||
// Extensions
|
||||
#include "ext/anycollseq.c"
|
||||
#include "ext/base64.c"
|
||||
#include "ext/decimal.c"
|
||||
#include "ext/regexp.c"
|
||||
#include "ext/series.c"
|
||||
#include "ext/uint.c"
|
||||
#include "ext/uuid.c"
|
||||
#include "func.c"
|
||||
#include "progress.c"
|
||||
#include "time.c"
|
||||
|
||||
__attribute__((constructor)) void init() {
|
||||
@@ -25,4 +23,4 @@ __attribute__((constructor)) void init() {
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
|
||||
}
|
||||
}
|
||||
9
sqlite3/progress.c
Normal file
9
sqlite3/progress.c
Normal file
@@ -0,0 +1,9 @@
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_progress(void *);
|
||||
|
||||
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
|
||||
sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL);
|
||||
}
|
||||
@@ -5,12 +5,28 @@
|
||||
#define SQLITE_OS_OTHER 1
|
||||
#define SQLITE_BYTEORDER 1234
|
||||
|
||||
#define HAVE_INT8_T 1
|
||||
#define HAVE_INT16_T 1
|
||||
#define HAVE_INT32_T 1
|
||||
#define HAVE_INT64_T 1
|
||||
#define HAVE_UINT8_T 1
|
||||
#define HAVE_UINT16_T 1
|
||||
#define HAVE_UINT32_T 1
|
||||
#define HAVE_UINT64_T 1
|
||||
#define HAVE_STDINT_H 1
|
||||
#define HAVE_INTTYPES_H 1
|
||||
|
||||
#define HAVE_LOG2 1
|
||||
#define HAVE_LOG10 1
|
||||
#define HAVE_ISNAN 1
|
||||
|
||||
#define HAVE_USLEEP 1
|
||||
#define HAVE_NANOSLEEP 1
|
||||
|
||||
#define HAVE_GMTIME_R 1
|
||||
#define HAVE_LOCALTIME_S 1
|
||||
|
||||
#define HAVE_MALLOC_H 1
|
||||
#define HAVE_MALLOC_USABLE_SIZE 1
|
||||
|
||||
// Recommended Options
|
||||
@@ -23,12 +39,12 @@
|
||||
#define SQLITE_MAX_EXPR_DEPTH 0
|
||||
#define SQLITE_OMIT_DECLTYPE
|
||||
#define SQLITE_OMIT_DEPRECATED
|
||||
#define SQLITE_OMIT_PROGRESS_CALLBACK
|
||||
#define SQLITE_OMIT_SHARED_CACHE
|
||||
#define SQLITE_OMIT_AUTOINIT
|
||||
#define SQLITE_USE_ALLOCA
|
||||
|
||||
// Other Options
|
||||
|
||||
#define SQLITE_ALLOW_URI_AUTHORITY
|
||||
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
|
||||
#define SQLITE_ENABLE_ATOMIC_WRITE
|
||||
@@ -36,12 +52,9 @@
|
||||
|
||||
// Because WASM does not support shared memory,
|
||||
// SQLite disables WAL for WASM builds.
|
||||
// We set the default locking mode to EXCLUSIVE instead.
|
||||
// We patch SQLite to use exclusive locking mode instead.
|
||||
// https://www.sqlite.org/wal.html#noshm
|
||||
#undef SQLITE_OMIT_WAL
|
||||
#ifndef SQLITE_DEFAULT_LOCKING_MODE
|
||||
#define SQLITE_DEFAULT_LOCKING_MODE 1
|
||||
#endif
|
||||
|
||||
// Amalgamated Extensions
|
||||
|
||||
@@ -54,6 +67,8 @@
|
||||
#define SQLITE_ENABLE_RTREE 1
|
||||
#define SQLITE_ENABLE_GEOPOLY 1
|
||||
|
||||
#define SQLITE_SOUNDEX
|
||||
|
||||
// Session Extension
|
||||
// #define SQLITE_ENABLE_SESSION
|
||||
// #define SQLITE_ENABLE_PREUPDATE_HOOK
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
@@ -26,7 +27,63 @@ static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
|
||||
return rc;
|
||||
}
|
||||
|
||||
static void json_time_func(sqlite3_context *context, int argc,
|
||||
sqlite3_value **argv) {
|
||||
DateTime x;
|
||||
if (isDate(context, argc, argv, &x)) return;
|
||||
if (x.tzSet && x.tz) {
|
||||
x.iJD += x.tz * 60000;
|
||||
if (!validJulianDay(x.iJD)) return;
|
||||
x.validYMD = 0;
|
||||
x.validHMS = 0;
|
||||
}
|
||||
computeYMD_HMS(&x);
|
||||
|
||||
sqlite3 *db = sqlite3_context_db_handle(context);
|
||||
sqlite3_str *res = sqlite3_str_new(db);
|
||||
|
||||
sqlite3_str_appendf(res, "%04d-%02d-%02dT%02d:%02d:%02d", //
|
||||
x.Y, x.M, x.D, //
|
||||
x.h, x.m, (int)(x.iJD / 1000 % 60));
|
||||
|
||||
if (x.useSubsec) {
|
||||
int rem = x.iJD % 1000;
|
||||
if (rem) {
|
||||
sqlite3_str_appendchar(res, 1, '.');
|
||||
sqlite3_str_appendchar(res, 1, '0' + rem / 100);
|
||||
if ((rem %= 100)) {
|
||||
sqlite3_str_appendchar(res, 1, '0' + rem / 10);
|
||||
if ((rem %= 10)) {
|
||||
sqlite3_str_appendchar(res, 1, '0' + rem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (x.tz) {
|
||||
sqlite3_str_appendf(res, "%+03d:%02d", x.tz / 60, abs(x.tz) % 60);
|
||||
} else {
|
||||
sqlite3_str_appendchar(res, 1, 'Z');
|
||||
}
|
||||
|
||||
int rc = sqlite3_str_errcode(res);
|
||||
if (rc) {
|
||||
sqlite3_result_error_code(context, rc);
|
||||
return;
|
||||
}
|
||||
|
||||
int n = sqlite3_str_length(res);
|
||||
sqlite3_result_text(context, sqlite3_str_finish(res), n, sqlite3_free);
|
||||
}
|
||||
|
||||
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
|
||||
const sqlite3_api_routines *pApi) {
|
||||
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
|
||||
sqlite3_create_collation_v2(db, "time", SQLITE_UTF8, /*arg=*/NULL,
|
||||
time_collation,
|
||||
/*destroy=*/NULL);
|
||||
sqlite3_create_function_v2(
|
||||
db, "json_time", -1,
|
||||
SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, /*arg=*/NULL,
|
||||
json_time_func, /*step=*/NULL, /*final=*/NULL, /*destroy=*/NULL);
|
||||
return SQLITE_OK;
|
||||
}
|
||||
45
sqlite3/timezone.patch
Normal file
45
sqlite3/timezone.patch
Normal file
@@ -0,0 +1,45 @@
|
||||
# Set UTC timezone, compute local offset.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -340,6 +340,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
|
||||
p->iJD = sqlite3StmtCurrentTime(context);
|
||||
if( p->iJD>0 ){
|
||||
p->validJD = 1;
|
||||
+ p->tzSet = 1;
|
||||
return 0;
|
||||
}else{
|
||||
return 1;
|
||||
@@ -355,6 +356,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
|
||||
static void setRawDateNumber(DateTime *p, double r){
|
||||
p->s = r;
|
||||
p->rawS = 1;
|
||||
+ p->tzSet = 1;
|
||||
if( r>=0.0 && r<5373484.5 ){
|
||||
p->iJD = (sqlite3_int64)(r*86400000.0 + 0.5);
|
||||
p->validJD = 1;
|
||||
@@ -731,7 +733,16 @@ static int parseModifier(
|
||||
** show local time.
|
||||
*/
|
||||
if( sqlite3_stricmp(z, "localtime")==0 && sqlite3NotPureFunc(pCtx) ){
|
||||
- rc = toLocaltime(p, pCtx);
|
||||
+ if( p->tzSet!=0 || p->tz==0 ) {
|
||||
+ rc = toLocaltime(p, pCtx);
|
||||
+ i64 iOrigJD = p->iJD;
|
||||
+ p->tzSet = 0;
|
||||
+ computeJD(p);
|
||||
+ p->tz = (p->iJD-iOrigJD)/60000;
|
||||
+ if( abs(p->tz)>= 900 ) p->tz = 0;
|
||||
+ } else {
|
||||
+ rc = 0;
|
||||
+ }
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -781,6 +792,7 @@ static int parseModifier(
|
||||
p->validJD = 1;
|
||||
p->tzSet = 1;
|
||||
}
|
||||
+ p->tz = 0;
|
||||
rc = SQLITE_OK;
|
||||
}
|
||||
#endif
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <time.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
@@ -90,22 +92,25 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
|
||||
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
if (zVfsName) {
|
||||
static sqlite3_vfs *go_vfs_list;
|
||||
sqlite3_vfs *found = NULL;
|
||||
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
|
||||
sqlite3_vfs *it = *next;
|
||||
|
||||
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
|
||||
if (!strcmp(zVfsName, it->zName) && go_vfs_find(it->zName)) {
|
||||
return it;
|
||||
}
|
||||
}
|
||||
|
||||
for (sqlite3_vfs **ptr = &go_vfs_list; *ptr;) {
|
||||
sqlite3_vfs *it = *ptr;
|
||||
if (go_vfs_find(it->zName)) {
|
||||
if (!strcmp(zVfsName, it->zName)) found = it;
|
||||
next = &it->pNext;
|
||||
ptr = &it->pNext;
|
||||
} else {
|
||||
*next = it->pNext;
|
||||
*ptr = it->pNext;
|
||||
free(it);
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return found;
|
||||
}
|
||||
|
||||
if (go_vfs_find(zVfsName)) {
|
||||
sqlite3_vfs *prev = go_vfs_list;
|
||||
sqlite3_vfs *head = go_vfs_list;
|
||||
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
|
||||
char *name = (char *)(go_vfs_list + 1);
|
||||
strcpy(name, zVfsName);
|
||||
@@ -114,7 +119,7 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = name,
|
||||
.pNext = prev,
|
||||
.pNext = head,
|
||||
|
||||
.xOpen = go_open_wrapper,
|
||||
.xDelete = go_delete,
|
||||
@@ -132,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
return sqlite3_vfs_find_orig(zVfsName);
|
||||
}
|
||||
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 280, "Unexpected offset");
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
@@ -1,3 +1,4 @@
|
||||
# Wrap sqlite3_vfs_find.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -25394,7 +25394,7 @@
|
||||
|
||||
@@ -3,6 +3,7 @@ package sqlite3
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -12,67 +13,73 @@ func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
|
||||
func TestConn_error_OOM(t *testing.T) {
|
||||
func Test_sqlite_error_OOM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
m.error(uint64(NOMEM), 0)
|
||||
sqlite.error(uint64(NOMEM), 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_call_closed(t *testing.T) {
|
||||
func Test_sqlite_call_closed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.close()
|
||||
sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
m.call(m.api.free)
|
||||
sqlite.call(sqlite.api.free)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_new(t *testing.T) {
|
||||
func Test_sqlite_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
t.Run("MaxUint32", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
m.new(math.MaxUint32)
|
||||
sqlite.new(math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
})
|
||||
|
||||
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
if os.Getenv("CI") != "" {
|
||||
t.Skip("skipping in CI")
|
||||
}
|
||||
defer func() { _ = recover() }()
|
||||
m.new(_MAX_ALLOCATION_SIZE)
|
||||
m.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
t.Error("want panic")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConn_newArena(t *testing.T) {
|
||||
func Test_sqlite_newArena(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
arena := m.newArena(16)
|
||||
arena := sqlite.newArena(16)
|
||||
defer arena.free()
|
||||
|
||||
const title = "Lorem ipsum"
|
||||
@@ -80,7 +87,7 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
@@ -89,7 +96,7 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
|
||||
t.Errorf("got %q, want %q", got, body)
|
||||
}
|
||||
|
||||
@@ -101,121 +108,130 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
arena.free()
|
||||
}
|
||||
|
||||
func TestConn_newBytes(t *testing.T) {
|
||||
func Test_sqlite_newBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newBytes(nil)
|
||||
ptr := sqlite.newBytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("got %#x, want nullptr", ptr)
|
||||
}
|
||||
|
||||
buf := []byte("sqlite3")
|
||||
ptr = m.newBytes(buf)
|
||||
ptr = sqlite.newBytes(buf)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := buf
|
||||
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
ptr = sqlite.newBytes(buf[:0])
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
if got := util.View(sqlite.mod, ptr, 0); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_newString(t *testing.T) {
|
||||
func Test_sqlite_newString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newString("")
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3\000sqlite3"
|
||||
ptr = m.newString(str)
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := str + "\000"
|
||||
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_getString(t *testing.T) {
|
||||
func Test_sqlite_getString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newString("")
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3" + "\000 drop this"
|
||||
ptr = m.newString(str)
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := "sqlite3"
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, 0); got != "" {
|
||||
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(m.mod, ptr, uint32(len(want)/2))
|
||||
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
|
||||
t.Error("want panic")
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(m.mod, 0, math.MaxUint32)
|
||||
util.ReadString(sqlite.mod, 0, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}()
|
||||
}
|
||||
|
||||
func TestConn_free(t *testing.T) {
|
||||
func Test_sqlite_free(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
m.free(0)
|
||||
sqlite.free(0)
|
||||
|
||||
ptr := m.new(1)
|
||||
ptr := sqlite.new(1)
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
m.free(ptr)
|
||||
sqlite.free(ptr)
|
||||
}
|
||||
92
stmt.go
92
stmt.go
@@ -1,7 +1,9 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -61,12 +63,12 @@ func (s *Stmt) ClearBindings() error {
|
||||
func (s *Stmt) Step() bool {
|
||||
s.c.checkInterrupt()
|
||||
r := s.c.call(s.c.api.step, uint64(s.handle))
|
||||
if r == _ROW {
|
||||
switch r {
|
||||
case _ROW:
|
||||
return true
|
||||
}
|
||||
if r == _DONE {
|
||||
case _DONE:
|
||||
s.err = nil
|
||||
} else {
|
||||
default:
|
||||
s.err = s.c.error(r)
|
||||
}
|
||||
return false
|
||||
@@ -131,10 +133,11 @@ func (s *Stmt) BindName(param int) string {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindBool(param int, value bool) error {
|
||||
var i int64
|
||||
if value {
|
||||
return s.BindInt64(param, 1)
|
||||
i = 1
|
||||
}
|
||||
return s.BindInt64(param, 0)
|
||||
return s.BindInt64(param, i)
|
||||
}
|
||||
|
||||
// BindInt binds an int to the prepared statement.
|
||||
@@ -234,7 +237,7 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
|
||||
}
|
||||
|
||||
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
const maxlen = uint64(len(time.RFC3339Nano)) + 5
|
||||
|
||||
ptr := s.c.new(maxlen)
|
||||
buf := util.View(s.c.mod, ptr, maxlen)
|
||||
@@ -247,6 +250,35 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
|
||||
// but it also associates ptr with that NULL value such that it can be retrieved
|
||||
// within an application-defined SQL function using [Value.Pointer].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindPointer(param int, ptr any) error {
|
||||
valPtr := util.AddHandle(s.c.ctx, ptr)
|
||||
r := s.c.call(s.c.api.bindPointer,
|
||||
uint64(s.handle), uint64(param), uint64(valPtr))
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindJSON binds the JSON encoding of value to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindJSON(param int, value any) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ptr := s.c.newBytes(data)
|
||||
r := s.c.call(s.c.api.bindText,
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(data)),
|
||||
uint64(s.c.api.destructor), _UTF8)
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_count.html
|
||||
@@ -374,18 +406,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
|
||||
func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnText,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
ptr := uint32(r)
|
||||
if ptr == 0 {
|
||||
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r = s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
// ColumnRawBlob returns the value of the result column as a []byte.
|
||||
@@ -397,20 +418,45 @@ func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
func (s *Stmt) ColumnRawBlob(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnBlob,
|
||||
uint64(s.handle), uint64(col))
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
ptr := uint32(r)
|
||||
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r = s.c.call(s.c.api.columnBytes,
|
||||
r := s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
}
|
||||
|
||||
// ColumnJSON parses the JSON-encoded value of the result column
|
||||
// and stores it in the value pointed to by ptr.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnJSON(col int, ptr any) error {
|
||||
var data []byte
|
||||
switch s.ColumnType(col) {
|
||||
case NULL:
|
||||
data = append(data, "null"...)
|
||||
case TEXT:
|
||||
data = s.ColumnRawText(col)
|
||||
case BLOB:
|
||||
data = s.ColumnRawBlob(col)
|
||||
case INTEGER:
|
||||
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
|
||||
case FLOAT:
|
||||
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return json.Unmarshal(data, ptr)
|
||||
}
|
||||
|
||||
// Return true if stmt is an empty SQL statement.
|
||||
// This is used as an optimization.
|
||||
// It's OK to always return false here.
|
||||
|
||||
@@ -43,7 +43,7 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.Join(t.TempDir(), "foo.db")+
|
||||
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
|
||||
"?_pragma=busy_timeout(10000)&_pragma=synchronous(off)")
|
||||
if err != nil {
|
||||
t.Fatalf("foo.db open fail: %v", err)
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
|
||||
defer stmt.Close()
|
||||
|
||||
db.SetInterrupt(ctx)
|
||||
cancel()
|
||||
go cancel()
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/wal.db
|
||||
var waldb []byte
|
||||
|
||||
func TestDB_memory(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, ":memory:")
|
||||
@@ -19,6 +25,23 @@ func TestDB_file(t *testing.T) {
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func TestDB_nolock(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, "file:"+
|
||||
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
|
||||
"?nolock=1")
|
||||
}
|
||||
|
||||
func TestDB_wal(t *testing.T) {
|
||||
t.Parallel()
|
||||
wal := filepath.Join(t.TempDir(), "test.db")
|
||||
err := os.WriteFile(wal, waldb, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDB(t, wal)
|
||||
}
|
||||
|
||||
func TestDB_vfs(t *testing.T) {
|
||||
testDB(t, "file:test.db?vfs=memdb")
|
||||
}
|
||||
|
||||
@@ -2,10 +2,9 @@ package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
218
tests/func_test.go
Normal file
218
tests/func_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestCreateFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg := arg[0]; arg.Int() {
|
||||
case 0:
|
||||
ctx.ResultInt(arg.Int())
|
||||
case 1:
|
||||
ctx.ResultInt64(arg.Int64())
|
||||
case 2:
|
||||
ctx.ResultBool(arg.Bool())
|
||||
case 3:
|
||||
ctx.ResultFloat(arg.Float())
|
||||
case 4:
|
||||
ctx.ResultText(arg.Text())
|
||||
case 5:
|
||||
ctx.ResultBlob(arg.Blob(nil))
|
||||
case 6:
|
||||
ctx.ResultZeroBlob(arg.Int64())
|
||||
case 7:
|
||||
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
|
||||
case 8:
|
||||
var v any
|
||||
if err := arg.JSON(&v); err != nil {
|
||||
ctx.ResultError(err)
|
||||
} else {
|
||||
ctx.ResultJSON(v)
|
||||
}
|
||||
case 9:
|
||||
ctx.ResultValue(arg)
|
||||
case 10:
|
||||
ctx.ResultNull()
|
||||
case 11:
|
||||
ctx.ResultError(sqlite3.FULL)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 1 {
|
||||
t.Errorf("got %v, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 3 {
|
||||
t.Errorf("got %v, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "4" {
|
||||
t.Errorf("got %s, want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
|
||||
t.Errorf("got %s, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
|
||||
t.Errorf("got %v, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
var got int
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != 8 {
|
||||
t.Errorf("got %v, want 8", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 9 {
|
||||
t.Errorf("got %v, want 9", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
t.Error("want error")
|
||||
}
|
||||
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
|
||||
t.Errorf("got %v, want sqlite3.FULL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnyCollationNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.AnyCollationNeeded()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 2, 1}
|
||||
names := []string{"go", "whatever", "zig"}
|
||||
for ; stmt.Step(); row++ {
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
68
tests/json_test.go
Normal file
68
tests/json_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/julianday"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
|
||||
nil, 1, math.Pi, reference,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
|
||||
sqlite3.JSON(math.Pi), sqlite3.JSON(false),
|
||||
julianday.Format(reference), sqlite3.JSON([]string{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query("SELECT * FROM test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := []string{
|
||||
"null", "1", "3.141592653589793",
|
||||
`"2013-10-07T04:23:19.12-04:00"`,
|
||||
"3.141592653589793", "false",
|
||||
"2456572.849526851851852", "[]",
|
||||
}
|
||||
for rows.Next() {
|
||||
var got json.RawMessage
|
||||
err = rows.Scan(sqlite3.JSON(&got))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != want[0] {
|
||||
t.Errorf("got %q, want %q", got, want[0])
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,6 @@ func TestParallel(t *testing.T) {
|
||||
name := "file:" +
|
||||
filepath.Join(t.TempDir(), "test.db") +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
@@ -42,7 +41,6 @@ func TestMemory(t *testing.T) {
|
||||
|
||||
name := "file:/test.db?vfs=memdb" +
|
||||
"&_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(memory)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
@@ -59,7 +57,6 @@ func TestMultiProcess(t *testing.T) {
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
@@ -93,7 +90,6 @@ func TestChildProcess(t *testing.T) {
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
@@ -128,10 +124,7 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA busy_timeout=10000;
|
||||
PRAGMA locking_mode=normal;
|
||||
`)
|
||||
err = db.Exec(`PRAGMA busy_timeout=10000`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
82
tests/quote_test.go
Normal file
82
tests/quote_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
val any
|
||||
want string
|
||||
}{
|
||||
{`abc`, "'abc'"},
|
||||
{`a"bc`, "'a\"bc'"},
|
||||
{`a'bc`, "'a''bc'"},
|
||||
{"\x07bc", "'\abc'"},
|
||||
{"\x1c\n", "'\x1c\n'"},
|
||||
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
|
||||
{"\xB0\x00\x0B", ""},
|
||||
|
||||
{0, "0"},
|
||||
{true, "1"},
|
||||
{false, "0"},
|
||||
{nil, "NULL"},
|
||||
{math.NaN(), "NULL"},
|
||||
{math.Inf(1), "9.0e999"},
|
||||
{math.Inf(-1), "-9.0e999"},
|
||||
{math.Pi, "3.141592653589793"},
|
||||
{int64(math.MaxInt64), "9223372036854775807"},
|
||||
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
|
||||
{sqlite3.ZeroBlob(4), "x'00000000'"},
|
||||
{sqlite3.ZeroBlob(1e9), ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil && tt.want != "" {
|
||||
t.Errorf("Quote(%q) = %v", tt.val, r)
|
||||
}
|
||||
}()
|
||||
|
||||
got := sqlite3.Quote(tt.val)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{`abc`, `"abc"`},
|
||||
{`a"bc`, `"a""bc"`},
|
||||
{`a'bc`, `"a'bc"`},
|
||||
{"\x07bc", "\"\abc\""},
|
||||
{"\x1c\n", "\"\x1c\n\""},
|
||||
{"\xB0\x00\x0B", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil && tt.want != "" {
|
||||
t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r)
|
||||
}
|
||||
}()
|
||||
|
||||
got := sqlite3.QuoteIdentifier(tt.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBlob(1, []byte("")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindJSON(1, true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.ClearBindings(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
|
||||
// The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
|
||||
stmt, _, err = db.Prepare(`SELECT col FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
|
||||
t.Errorf("got %q, want zero", got)
|
||||
}
|
||||
var got int
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
|
||||
t.Errorf("got %q, want one", got)
|
||||
}
|
||||
var got float32
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != 1 {
|
||||
t.Errorf("got %v, want one", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
|
||||
t.Errorf("got %q, want two", got)
|
||||
}
|
||||
var got json.Number
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != "2" {
|
||||
t.Errorf("got %v, want two", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
|
||||
t.Errorf("got %q, want π", got)
|
||||
}
|
||||
var got float64
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != math.Pi {
|
||||
t.Errorf("got %v, want π", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != nil {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
|
||||
t.Errorf(`got %q, want "text"`, got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
|
||||
t.Errorf(`got %q, want "blob"`, got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != nil {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "true" {
|
||||
t.Errorf("got %q, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
|
||||
t.Errorf("got %q, want true", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != nil {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
|
||||
BIN
tests/testdata/wal.db
vendored
Normal file
BIN
tests/testdata/wal.db
vendored
Normal file
Binary file not shown.
@@ -1,11 +1,13 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
func TestTimeFormat_Encode(t *testing.T) {
|
||||
@@ -39,70 +41,72 @@ func TestTimeFormat_Encode(t *testing.T) {
|
||||
func TestTimeFormat_Decode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
zone := time.FixedZone("", -4*3600)
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, zone)
|
||||
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, zone)
|
||||
|
||||
tests := []struct {
|
||||
fmt sqlite3.TimeFormat
|
||||
val any
|
||||
want time.Time
|
||||
wantDelta time.Duration
|
||||
wantLoc *time.Location
|
||||
wantErr bool
|
||||
}{
|
||||
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
|
||||
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
|
||||
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
|
||||
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, time.UTC, false},
|
||||
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
|
||||
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
|
||||
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
|
||||
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, zone, false},
|
||||
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
|
||||
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
|
||||
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
|
||||
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, zone, false},
|
||||
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormat3, false, time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormat9, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
|
||||
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, zone, false},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -112,13 +116,48 @@ func TestTimeFormat_Decode(t *testing.T) {
|
||||
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want.Sub(got).Abs() > tt.wantDelta {
|
||||
if got.Sub(tt.want).Abs() > tt.wantDelta {
|
||||
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
|
||||
}
|
||||
if got.Location().String() != tt.wantLoc.String() {
|
||||
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got.Location(), tt.wantLoc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeFormat_Scanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(
|
||||
`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.TimeFormat7TZ.Encode(reference))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got time.Time
|
||||
err = db.QueryRow("SELECT * FROM test").Scan(sqlite3.TimeFormatAuto.Scanner(&got))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.Equal(reference) {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_timeCollation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -167,3 +206,57 @@ func TestDB_timeCollation(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_isoWeek(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []time.Time{
|
||||
time.Date(1977, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1977, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1977, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1978, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1978, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1978, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1979, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1979, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1979, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 28, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 29, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 30, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1981, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1981, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1982, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1982, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1982, 1, 3, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT strftime('%G-W%V-%u', ?)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, tm := range tests {
|
||||
stmt.BindTime(1, tm, sqlite3.TimeFormatDefault)
|
||||
if stmt.Step() {
|
||||
y, w := tm.ISOWeek()
|
||||
d := tm.Weekday()
|
||||
if d == 0 {
|
||||
d = 7
|
||||
}
|
||||
want := fmt.Sprintf("%04d-W%02d-%d", y, w, d)
|
||||
if got := stmt.ColumnText(0); got != want {
|
||||
t.Errorf("got %q, want %q (%v)", got, want, tm)
|
||||
}
|
||||
}
|
||||
stmt.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
56
time.go
56
time.go
@@ -164,9 +164,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
case float64:
|
||||
sec, frac := math.Modf(v)
|
||||
nsec := math.Floor(frac * 1e9)
|
||||
return time.Unix(int64(sec), int64(nsec)), nil
|
||||
return time.Unix(int64(sec), int64(nsec)).UTC(), nil
|
||||
case int64:
|
||||
return time.Unix(v, 0), nil
|
||||
return time.Unix(v, 0).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -181,9 +181,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMilli(int64(math.Floor(v))), nil
|
||||
return time.UnixMilli(int64(math.Floor(v))).UTC(), nil
|
||||
case int64:
|
||||
return time.UnixMilli(int64(v)), nil
|
||||
return time.UnixMilli(int64(v)).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -198,9 +198,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMicro(int64(math.Floor(v))), nil
|
||||
return time.UnixMicro(int64(math.Floor(v))).UTC(), nil
|
||||
case int64:
|
||||
return time.UnixMicro(int64(v)), nil
|
||||
return time.UnixMicro(int64(v)).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -215,9 +215,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.Unix(0, int64(math.Floor(v))), nil
|
||||
return time.Unix(0, int64(math.Floor(v))).UTC(), nil
|
||||
case int64:
|
||||
return time.Unix(0, int64(v)), nil
|
||||
return time.Unix(0, int64(v)).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -238,26 +238,16 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
|
||||
dates := []TimeFormat{
|
||||
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
|
||||
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
|
||||
TimeFormat1,
|
||||
TimeFormat9, TimeFormat8,
|
||||
TimeFormat6, TimeFormat5,
|
||||
TimeFormat3, TimeFormat2, TimeFormat1,
|
||||
}
|
||||
for _, f := range dates {
|
||||
t, err := time.Parse(string(f), s)
|
||||
t, err := f.Decode(s)
|
||||
if err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
times := []TimeFormat{
|
||||
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
|
||||
}
|
||||
for _, f := range times {
|
||||
t, err := time.Parse(string(f), s)
|
||||
if err == nil {
|
||||
return t.AddDate(2000, 0, 0), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
@@ -314,7 +304,10 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
t, err := f.parseRelaxed(s)
|
||||
return t.AddDate(2000, 0, 0), err
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.AddDate(2000, 0, 0), nil
|
||||
|
||||
default:
|
||||
s, ok := v.(string)
|
||||
@@ -338,3 +331,20 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Scanner returns a [database/sql.Scanner] that can be used as an argument to
|
||||
// [database/sql.Row.Scan] and similar methods to
|
||||
// decode a time value into dest using this format.
|
||||
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
|
||||
return timeScanner{dest, f}
|
||||
}
|
||||
|
||||
type timeScanner struct {
|
||||
*time.Time
|
||||
TimeFormat
|
||||
}
|
||||
|
||||
func (s timeScanner) Scan(src any) (err error) {
|
||||
*s.Time, err = s.Decode(src)
|
||||
return
|
||||
}
|
||||
|
||||
33
tx.go
33
tx.go
@@ -7,6 +7,7 @@ import (
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Tx is an in-progress database transaction.
|
||||
@@ -119,17 +120,8 @@ type Savepoint struct {
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func (c *Conn) Savepoint() Savepoint {
|
||||
name := "sqlite3.Savepoint"
|
||||
var pc [1]uintptr
|
||||
if n := runtime.Callers(2, pc[:]); n > 0 {
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, _ := frames.Next()
|
||||
if frame.Function != "" {
|
||||
name = frame.Function
|
||||
}
|
||||
}
|
||||
// Names can be reused; this makes catching bugs more likely.
|
||||
name += "#" + strconv.Itoa(int(rand.Int31()))
|
||||
name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
|
||||
|
||||
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
|
||||
if err != nil {
|
||||
@@ -138,6 +130,27 @@ func (c *Conn) Savepoint() Savepoint {
|
||||
return Savepoint{c: c, name: name}
|
||||
}
|
||||
|
||||
func saveptName() (name string) {
|
||||
defer func() {
|
||||
if name == "" {
|
||||
name = "sqlite3.Savepoint"
|
||||
}
|
||||
}()
|
||||
|
||||
var pc [8]uintptr
|
||||
n := runtime.Callers(3, pc[:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, more := frames.Next()
|
||||
for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
|
||||
strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
|
||||
frame, more = frames.Next()
|
||||
}
|
||||
return frame.Function
|
||||
}
|
||||
|
||||
// Release releases the savepoint rolling back any changes
|
||||
// if *error points to a non-nil error.
|
||||
//
|
||||
|
||||
155
value.go
Normal file
155
value.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Value is any value that can be stored in a database table.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value.html
|
||||
type Value struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Type returns the initial [Datatype] of the value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Type() Datatype {
|
||||
r := v.call(v.api.valueType, uint64(v.handle))
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
// Bool returns the value as a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are retrieved as integers,
|
||||
// with 0 converted to false and any other value to true.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Bool() bool {
|
||||
if i := v.Int64(); i != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Int returns the value as an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int() int {
|
||||
return int(v.Int64())
|
||||
}
|
||||
|
||||
// Int64 returns the value as an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int64() int64 {
|
||||
r := v.call(v.api.valueInteger, uint64(v.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// Float returns the value as a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Float() float64 {
|
||||
r := v.call(v.api.valueFloat, uint64(v.handle))
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// Time returns the value as a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Time(format TimeFormat) time.Time {
|
||||
var a any
|
||||
switch v.Type() {
|
||||
case INTEGER:
|
||||
a = v.Int64()
|
||||
case FLOAT:
|
||||
a = v.Float()
|
||||
case TEXT, BLOB:
|
||||
a = v.Text()
|
||||
case NULL:
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
t, _ := format.Decode(a)
|
||||
return t
|
||||
}
|
||||
|
||||
// Text returns the value as a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Text() string {
|
||||
return string(v.RawText())
|
||||
}
|
||||
|
||||
// Blob appends to buf and returns
|
||||
// the value as a []byte.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Blob(buf []byte) []byte {
|
||||
return append(buf, v.RawBlob()...)
|
||||
}
|
||||
|
||||
// RawText returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawText() []byte {
|
||||
r := v.call(v.api.valueText, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
// RawBlob returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawBlob() []byte {
|
||||
r := v.call(v.api.valueBlob, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
func (v Value) rawBytes(ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := v.call(v.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
|
||||
// Pointer gets the pointer associated with this value,
|
||||
// or nil if it has no associated pointer.
|
||||
func (v Value) Pointer() any {
|
||||
r := v.call(v.api.valuePointer, uint64(v.handle))
|
||||
return util.GetHandle(v.ctx, uint32(r))
|
||||
}
|
||||
|
||||
// JSON parses a JSON-encoded value
|
||||
// and stores the result in the value pointed to by ptr.
|
||||
func (v Value) JSON(ptr any) error {
|
||||
var data []byte
|
||||
switch v.Type() {
|
||||
case NULL:
|
||||
data = append(data, "null"...)
|
||||
case TEXT:
|
||||
data = v.RawText()
|
||||
case BLOB:
|
||||
data = v.RawBlob()
|
||||
case INTEGER:
|
||||
data = strconv.AppendInt(nil, v.Int64(), 10)
|
||||
case FLOAT:
|
||||
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return json.Unmarshal(data, ptr)
|
||||
}
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
|
||||
|
||||
It replaces the default VFS with a pure Go implementation,
|
||||
that is tested on Linux, macOS and Windows,
|
||||
but which should also work on illumos and the various BSDs.
|
||||
It replaces the default SQLite VFS with a pure Go implementation.
|
||||
|
||||
It also exposes interfaces that should allow you to implement your own custom VFSes.
|
||||
23
vfs/api.go
23
vfs/api.go
@@ -15,7 +15,7 @@ type VFS interface {
|
||||
FullPathname(name string) (string, error)
|
||||
}
|
||||
|
||||
// VFSParams extends VFS to with the ability to handle URI parameters
|
||||
// VFSParams extends VFS with the ability to handle URI parameters
|
||||
// through the OpenParams method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/uri_boolean.html
|
||||
@@ -47,7 +47,7 @@ type File interface {
|
||||
// FileLockState extends File to implement the
|
||||
// SQLITE_FCNTL_LOCKSTATE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntllockstate
|
||||
type FileLockState interface {
|
||||
File
|
||||
LockState() LockLevel
|
||||
@@ -56,7 +56,7 @@ type FileLockState interface {
|
||||
// FileSizeHint extends File to implement the
|
||||
// SQLITE_FCNTL_SIZE_HINT file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlsizehint
|
||||
type FileSizeHint interface {
|
||||
File
|
||||
SizeHint(size int64) error
|
||||
@@ -65,16 +65,25 @@ type FileSizeHint interface {
|
||||
// FileHasMoved extends File to implement the
|
||||
// SQLITE_FCNTL_HAS_MOVED file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlhasmoved
|
||||
type FileHasMoved interface {
|
||||
File
|
||||
HasMoved() (bool, error)
|
||||
}
|
||||
|
||||
// FileOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_OVERWRITE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntloverwrite
|
||||
type FileOverwrite interface {
|
||||
File
|
||||
Overwrite() error
|
||||
}
|
||||
|
||||
// FilePowersafeOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlpowersafeoverwrite
|
||||
type FilePowersafeOverwrite interface {
|
||||
File
|
||||
PowersafeOverwrite() bool
|
||||
@@ -84,7 +93,7 @@ type FilePowersafeOverwrite interface {
|
||||
// FilePowersafeOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlcommitphasetwo
|
||||
type FileCommitPhaseTwo interface {
|
||||
File
|
||||
CommitPhaseTwo() error
|
||||
@@ -94,7 +103,7 @@ type FileCommitPhaseTwo interface {
|
||||
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
|
||||
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlbeginatomicwrite
|
||||
type FileBatchAtomicWrite interface {
|
||||
File
|
||||
BeginAtomicWrite() error
|
||||
|
||||
9
vfs/clear.go
Normal file
9
vfs/clear.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !go1.21
|
||||
|
||||
package vfs
|
||||
|
||||
func clear(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
10
vfs/file.go
10
vfs/file.go
@@ -9,7 +9,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type vfsOS struct{}
|
||||
@@ -124,11 +123,10 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
|
||||
|
||||
type vfsFile struct {
|
||||
*os.File
|
||||
lockTimeout time.Duration
|
||||
lock LockLevel
|
||||
psow bool
|
||||
syncDir bool
|
||||
readOnly bool
|
||||
lock LockLevel
|
||||
psow bool
|
||||
syncDir bool
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
28
vfs/lock.go
28
vfs/lock.go
@@ -1,11 +1,6 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
import "github.com/ncruces/go-sqlite3/internal/util"
|
||||
|
||||
const (
|
||||
_PENDING_BYTE = 0x40000000
|
||||
@@ -48,7 +43,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
if f.lock != LOCK_NONE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
|
||||
if rc := osGetSharedLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_SHARED
|
||||
@@ -59,7 +54,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
if f.lock != LOCK_SHARED {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
|
||||
if rc := osGetReservedLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_RESERVED
|
||||
@@ -77,7 +72,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
}
|
||||
f.lock = LOCK_PENDING
|
||||
}
|
||||
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
|
||||
if rc := osGetExclusiveLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_EXCLUSIVE
|
||||
@@ -133,18 +128,3 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
|
||||
}
|
||||
return osCheckReservedLock(f.File)
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock.
|
||||
return osWriteLock(file, _PENDING_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
// Test the RESERVED lock.
|
||||
return osCheckLock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -12,13 +11,6 @@ import (
|
||||
)
|
||||
|
||||
func Test_vfsLock(t *testing.T) {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin", "windows":
|
||||
break
|
||||
default:
|
||||
t.Skip("OS lacks OFD locks")
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
|
||||
// Create a temporary file.
|
||||
@@ -41,8 +33,7 @@ func Test_vfsLock(t *testing.T) {
|
||||
pOutput = 32
|
||||
)
|
||||
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
|
||||
ctx, vfs := NewContext(context.TODO())
|
||||
defer vfs.Close()
|
||||
ctx := util.NewContext(context.TODO())
|
||||
|
||||
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
|
||||
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
|
||||
@@ -212,9 +203,4 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
}
|
||||
|
||||
10
vfs/memdb/clear.go
Normal file
10
vfs/memdb/clear.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !go1.21
|
||||
|
||||
package memdb
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
n = copy((*m.data[base])[rest:], b)
|
||||
if n < len(b) {
|
||||
// Assume writes are page aligned.
|
||||
return 0, io.ErrShortWrite
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
if size := off + int64(len(b)); size > m.size {
|
||||
m.size = size
|
||||
@@ -176,6 +176,8 @@ func (m *memFile) Size() (int64, error) {
|
||||
return m.size, nil
|
||||
}
|
||||
|
||||
const spinWait = 25 * time.Microsecond
|
||||
|
||||
func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
if m.lock >= lock {
|
||||
return nil
|
||||
@@ -187,17 +189,11 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
deadline := time.Now().Add(time.Millisecond)
|
||||
|
||||
switch lock {
|
||||
case vfs.LOCK_SHARED:
|
||||
for m.pending != nil {
|
||||
if time.Now().After(deadline) {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
if m.pending != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.shared++
|
||||
|
||||
@@ -216,8 +212,8 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
m.pending = m
|
||||
}
|
||||
|
||||
for m.shared > 1 {
|
||||
if time.Now().After(deadline) {
|
||||
for before := time.Now(); m.shared > 1; {
|
||||
if time.Since(before) > spinWait {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
@@ -291,10 +287,3 @@ func divRoundUp(a, b int64) int64 {
|
||||
func modRoundUp(a, b int64) int64 {
|
||||
return b - (b-a%b)%b
|
||||
}
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
|
||||
//go:build (freebsd || openbsd || netbsd || dragonfly || sqlite3_flock) && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
@@ -20,16 +20,16 @@ func osUnlock(file *os.File, start, len int64) _ErrorCode {
|
||||
}
|
||||
|
||||
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
|
||||
before := time.Now()
|
||||
var err error
|
||||
for {
|
||||
err = unix.Flock(int(file.Fd()), how)
|
||||
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
|
||||
break
|
||||
}
|
||||
if timeout < time.Millisecond {
|
||||
if timeout <= 0 || timeout < time.Since(before) {
|
||||
break
|
||||
}
|
||||
timeout -= time.Millisecond
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
return osLockErrorCode(err, def)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !sqlite3_bsd
|
||||
//go:build !sqlite3_flock && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user