mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 22:19:14 +00:00
Compare commits
239 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5d7cd1f2d | ||
|
|
a33a187d48 | ||
|
|
70c6ee15c6 | ||
|
|
994d9b1812 | ||
|
|
b19bd28ed3 | ||
|
|
e66bd51845 | ||
|
|
f5614bc2ed | ||
|
|
d9fcf60b7d | ||
|
|
ac6dd1aa5f | ||
|
|
b1495bd6cb | ||
|
|
2d91760295 | ||
|
|
38d4254bc4 | ||
|
|
c0aa734786 | ||
|
|
fa845dbd3d | ||
|
|
fed315ab79 | ||
|
|
726d7316f7 | ||
|
|
ddb387b021 | ||
|
|
d0f19507f5 | ||
|
|
9d997552ad | ||
|
|
9d75c39dcc | ||
|
|
746a84965e | ||
|
|
312d3b58f2 | ||
|
|
b71cd295c2 | ||
|
|
5b3b61a304 | ||
|
|
d661d15723 | ||
|
|
1e38165ad0 | ||
|
|
58a32d7c9d | ||
|
|
6765e883c1 | ||
|
|
18fc608433 | ||
|
|
77f37893b9 | ||
|
|
f1e36e2581 | ||
|
|
772b9153c7 | ||
|
|
4b280a3a7e | ||
|
|
19b6098bf6 | ||
|
|
2aa685320f | ||
|
|
9941be05c2 | ||
|
|
a0a9ab7737 | ||
|
|
a77727a1ce | ||
|
|
47fe032078 | ||
|
|
bdfe279444 | ||
|
|
a86937a54e | ||
|
|
6ef422fbde | ||
|
|
ff0cb6fb88 | ||
|
|
72db90efdf | ||
|
|
5a3fdef3c5 | ||
|
|
ff34b0cae1 | ||
|
|
f064492bb1 | ||
|
|
1427d30541 | ||
|
|
d3730341f0 | ||
|
|
78ac2386f6 | ||
|
|
632ea933b3 | ||
|
|
0f7fa6ebc9 | ||
|
|
6f7f776488 | ||
|
|
f6d7c5e9c5 | ||
|
|
1cc7ecfe8d | ||
|
|
3844e81404 | ||
|
|
fec1f8d32a | ||
|
|
31572e6095 | ||
|
|
4aee38b957 | ||
|
|
232a7705b5 | ||
|
|
a6c2fccd74 | ||
|
|
6a982559cd | ||
|
|
c7904d30de | ||
|
|
ce4386604d | ||
|
|
26b62c520d | ||
|
|
738714bf32 | ||
|
|
41b020bafc | ||
|
|
d0e720272b | ||
|
|
76171da12b | ||
|
|
dcc845d684 | ||
|
|
f1b42c26d5 | ||
|
|
1e94407ae7 | ||
|
|
eb8d9b95fd | ||
|
|
04037a75ed | ||
|
|
2472ceb0a0 | ||
|
|
bfe9bfde2e | ||
|
|
f07e82e361 | ||
|
|
fbbbe5a631 | ||
|
|
5ea603ed78 | ||
|
|
401cb77e38 | ||
|
|
6511175011 | ||
|
|
f7d987fdf1 | ||
|
|
00ba681bb5 | ||
|
|
d4d4533a41 | ||
|
|
ec9533b13f | ||
|
|
8fe77a065c | ||
|
|
7bf5312bd4 | ||
|
|
ae7b74d858 | ||
|
|
9a8de3ad13 | ||
|
|
05737e6025 | ||
|
|
ac2836bb82 | ||
|
|
d0d4b0e1a2 | ||
|
|
dc3dc6853d | ||
|
|
830240c368 | ||
|
|
dedec8682b | ||
|
|
a33b828e13 | ||
|
|
8b2e96dedc | ||
|
|
f1c46db512 | ||
|
|
7ca9d79424 | ||
|
|
254d473546 | ||
|
|
5639fc1ff8 | ||
|
|
ae4954d09b | ||
|
|
45937d9749 | ||
|
|
eee71e06aa | ||
|
|
9e7b6bb8ea | ||
|
|
597178f80d | ||
|
|
cc2d16ac83 | ||
|
|
cfb69e4ce7 | ||
|
|
e6969432e3 | ||
|
|
2b3da350cc | ||
|
|
336ba87d56 | ||
|
|
dd4823ebf0 | ||
|
|
663b23ff3b | ||
|
|
4e2ce6c635 | ||
|
|
66effb4249 | ||
|
|
e1cce83f71 | ||
|
|
df953b31c2 | ||
|
|
67cc3d35d5 | ||
|
|
6846b72b31 | ||
|
|
c94cdaf720 | ||
|
|
f6a887dd1c | ||
|
|
2a010a2022 | ||
|
|
c86b06b048 | ||
|
|
a44a13a506 | ||
|
|
4604719966 | ||
|
|
03168d5d34 | ||
|
|
be4b6304f9 | ||
|
|
b5e678a40a | ||
|
|
2fc4698ddc | ||
|
|
bd86539577 | ||
|
|
7a785d9aec | ||
|
|
59f79e8e74 | ||
|
|
40457721d7 | ||
|
|
18eeb85783 | ||
|
|
b36536979b | ||
|
|
a6226c3b31 | ||
|
|
bed2ee7674 | ||
|
|
7e6d178122 | ||
|
|
f360c77a78 | ||
|
|
759b11a05d | ||
|
|
93ce586139 | ||
|
|
2e5082c616 | ||
|
|
34acc28af8 | ||
|
|
c1a640f7d8 | ||
|
|
005b15610a | ||
|
|
23ee4ccb0b | ||
|
|
3a8cfd036d | ||
|
|
c38382fd8e | ||
|
|
8509e0b6c8 | ||
|
|
9c07e57252 | ||
|
|
80039385d3 | ||
|
|
89f4327b2b | ||
|
|
37a3ff37e8 | ||
|
|
d880d6842c | ||
|
|
bef46e7954 | ||
|
|
4e72b4d117 | ||
|
|
3b08d02a83 | ||
|
|
b19c12c4c7 | ||
|
|
859a21ef4e | ||
|
|
8ff0ee752f | ||
|
|
589ad86f76 | ||
|
|
1a3a1be1f6 | ||
|
|
222c217bc8 | ||
|
|
c1dc716391 | ||
|
|
71e1e5a8ee | ||
|
|
e4efb20c71 | ||
|
|
2c9459d907 | ||
|
|
d0875e5fab | ||
|
|
15dec13f15 | ||
|
|
f38e36109a | ||
|
|
4cb65ccbd9 | ||
|
|
f789c2fb8b | ||
|
|
c6a2617dfc | ||
|
|
6fc0afcd12 | ||
|
|
77088962f5 | ||
|
|
71da34861b | ||
|
|
56e8281bdb | ||
|
|
f61d430e65 | ||
|
|
dbaed53b9a | ||
|
|
8b1bfd04e3 | ||
|
|
11c1687146 | ||
|
|
94c43a8685 | ||
|
|
a25159a070 | ||
|
|
e007e9b060 | ||
|
|
66a730893f | ||
|
|
926adeb3f5 | ||
|
|
677f51bec1 | ||
|
|
5d6f92b733 | ||
|
|
f5747f19fb | ||
|
|
dfcdbf9c4c | ||
|
|
ad1e8f4b0e | ||
|
|
8f29882671 | ||
|
|
6c96a019e6 | ||
|
|
d291738b81 | ||
|
|
c1263d4f33 | ||
|
|
1ebdc1aa93 | ||
|
|
4dd10f071a | ||
|
|
7dbddfa5c0 | ||
|
|
ce5e035801 | ||
|
|
8bb8367a36 | ||
|
|
9f59b3d0ec | ||
|
|
5f893b5459 | ||
|
|
35b1a97b88 | ||
|
|
416c3863a0 | ||
|
|
dbc400eb15 | ||
|
|
35265271aa | ||
|
|
c7165a2e56 | ||
|
|
e64bffa520 | ||
|
|
54046b6adc | ||
|
|
1b3823483f | ||
|
|
ce6d0627b2 | ||
|
|
dd30215702 | ||
|
|
21aff4c9f5 | ||
|
|
b30f127547 | ||
|
|
6509e5deb2 | ||
|
|
125b8053f8 | ||
|
|
1e4a246d2f | ||
|
|
e6cd0aaf87 | ||
|
|
c1472a48b0 | ||
|
|
a69ab1ebe3 | ||
|
|
1190c21684 | ||
|
|
8c28c3a6f4 | ||
|
|
0146496036 | ||
|
|
fcd33d2f0f | ||
|
|
627df5db0f | ||
|
|
1ed62d300d | ||
|
|
5b2451c3ad | ||
|
|
d52e0371eb | ||
|
|
75f2644b0e | ||
|
|
71ae26e5c9 | ||
|
|
e91758c6a4 | ||
|
|
b749b32a62 | ||
|
|
3b4df71a94 | ||
|
|
df687a1c54 | ||
|
|
2f5b9837e1 | ||
|
|
c351400be7 | ||
|
|
231d3a0438 | ||
|
|
2f25e4eedb | ||
|
|
ad27d5d840 |
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
custom: https://www.paypal.com/donate?hosted_button_id=33P59ELZWGMK6
|
||||
45
.github/workflows/go.yml
vendored
45
.github/workflows/go.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -15,12 +16,30 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
- name: Set up
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: stable
|
||||
cache: true
|
||||
|
||||
- name: Format
|
||||
run: gofmt -s -w . && git diff --exit-code
|
||||
if: matrix.os != 'windows-latest'
|
||||
|
||||
- name: Tidy
|
||||
run: go mod tidy && git diff --exit-code
|
||||
|
||||
- name: Download
|
||||
run: go mod download
|
||||
|
||||
- name: Verify
|
||||
run: go mod verify
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
continue-on-error: true
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
@@ -28,13 +47,21 @@ jobs:
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
|
||||
- name: Test data races
|
||||
run: go test -v -race ./...
|
||||
- name: Test no locks
|
||||
run: go test -v -tags sqlite3_nolock .
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
- name: Update coverage report
|
||||
uses: ncruces/go-coverage-report@main
|
||||
- name: Test BSD locks
|
||||
run: go test -v -tags sqlite3_flock ./...
|
||||
if: matrix.os == 'macos-latest'
|
||||
|
||||
- name: Coverage report
|
||||
uses: ncruces/go-coverage-report@v0
|
||||
with:
|
||||
chart: 'true'
|
||||
amend: 'true'
|
||||
reuse-go: 'true'
|
||||
if: |
|
||||
matrix.os == 'ubuntu-latest' &&
|
||||
github.event_name == 'push'
|
||||
github.event_name == 'push' &&
|
||||
matrix.os == 'ubuntu-latest'
|
||||
continue-on-error: true
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -13,7 +13,4 @@
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
tools
|
||||
|
||||
# Project
|
||||
sqlite3/sqlite3*
|
||||
tools
|
||||
108
README.md
108
README.md
@@ -2,20 +2,100 @@
|
||||
|
||||
[](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
[](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
|
||||
[](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html)
|
||||
[](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report)
|
||||
|
||||
⚠️ CAUTION ⚠️
|
||||
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
|
||||
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
|
||||
|
||||
This is a WIP.\
|
||||
DO NOT USE with data you care about.
|
||||
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
|
||||
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
|
||||
implements an in-memory VFS.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
|
||||
implements a VFS for immutable databases.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
|
||||
registers Unicode aware functions.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
|
||||
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
|
||||
Roadmap:
|
||||
- [x] build SQLite using `zig cc --target=wasm32-wasi`
|
||||
- [x] `:memory:` databases
|
||||
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
|
||||
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
|
||||
- [x] design a simple, nice API, enough for simple use cases
|
||||
- [x] provide a simple `database/sql` driver
|
||||
- [x] file locking, compatible with SQLite on Windows/Unix
|
||||
- [ ] shared memory, compatible with SQLite on Windows/Unix
|
||||
- needed for improved WAL mode
|
||||
### Caveats
|
||||
|
||||
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html)
|
||||
(aka VFS) with a [pure Go](vfs/) implementation.
|
||||
This has benefits, but also comes with some drawbacks.
|
||||
|
||||
#### Write-Ahead Logging
|
||||
|
||||
Because WASM does not support shared memory,
|
||||
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
|
||||
|
||||
To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
|
||||
to always use `EXCLUSIVE` locking mode for WAL databases.
|
||||
|
||||
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
|
||||
to use the [`database/sql`](https://pkg.go.dev/database/sql)
|
||||
driver with WAL mode databases you should disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### POSIX Advisory Locks
|
||||
|
||||
POSIX advisory locks, which SQLite uses, are
|
||||
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
|
||||
|
||||
On Linux, macOS and illumos, this module uses
|
||||
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
|
||||
to synchronize access to database files.
|
||||
OFD locks are fully compatible with process-associated POSIX advisory locks.
|
||||
|
||||
On BSD Unixes, this module may use
|
||||
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
|
||||
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
|
||||
|
||||
##### TL;DR
|
||||
|
||||
In all platforms for which this package builds,
|
||||
it should be safe to use it to access databases concurrently,
|
||||
from multiple goroutines, processes, and
|
||||
with _other_ implementations of SQLite.
|
||||
|
||||
If the package does not build for your platform,
|
||||
see [this](vfs/README.md#portability).
|
||||
|
||||
#### Testing
|
||||
|
||||
The pure Go VFS is tested by running SQLite's
|
||||
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
|
||||
on Linux, macOS and Windows.
|
||||
Performance is tested by running
|
||||
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
|
||||
|
||||
### Roadmap
|
||||
|
||||
- [ ] advanced SQLite features
|
||||
- [x] custom functions
|
||||
- [x] nested transactions
|
||||
- [x] incremental BLOB I/O
|
||||
- [x] online backup
|
||||
- [ ] session extension
|
||||
- [ ] custom VFSes
|
||||
- [x] custom VFS API
|
||||
- [x] in-memory VFS
|
||||
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
|
||||
|
||||
### Alternatives
|
||||
|
||||
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
|
||||
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
|
||||
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
|
||||
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
|
||||
|
||||
112
api.go
112
api.go
@@ -1,112 +0,0 @@
|
||||
// Package sqlite3 wraps the C SQLite API.
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
|
||||
getFun := func(name string) api.Function {
|
||||
f := module.ExportedFunction(name)
|
||||
if f == nil {
|
||||
err = noFuncErr + errorString(name)
|
||||
return nil
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
getPtr := func(name string) uint32 {
|
||||
global := module.ExportedGlobal(name)
|
||||
if global == nil {
|
||||
err = noGlobalErr + errorString(name)
|
||||
return 0
|
||||
}
|
||||
return memory{module}.readUint32(uint32(global.Get()))
|
||||
}
|
||||
|
||||
c := Conn{
|
||||
ctx: ctx,
|
||||
mem: memory{module},
|
||||
api: sqliteAPI{
|
||||
malloc: getFun("malloc"),
|
||||
free: getFun("free"),
|
||||
destructor: uint64(getPtr("malloc_destructor")),
|
||||
errcode: getFun("sqlite3_errcode"),
|
||||
errstr: getFun("sqlite3_errstr"),
|
||||
errmsg: getFun("sqlite3_errmsg"),
|
||||
erroff: getFun("sqlite3_error_offset"),
|
||||
open: getFun("sqlite3_open_v2"),
|
||||
close: getFun("sqlite3_close"),
|
||||
prepare: getFun("sqlite3_prepare_v3"),
|
||||
finalize: getFun("sqlite3_finalize"),
|
||||
reset: getFun("sqlite3_reset"),
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
bindName: getFun("sqlite3_bind_parameter_name"),
|
||||
bindNull: getFun("sqlite3_bind_null"),
|
||||
bindInteger: getFun("sqlite3_bind_int64"),
|
||||
bindFloat: getFun("sqlite3_bind_double"),
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
columnInteger: getFun("sqlite3_column_int64"),
|
||||
columnFloat: getFun("sqlite3_column_double"),
|
||||
columnText: getFun("sqlite3_column_text"),
|
||||
columnBlob: getFun("sqlite3_column_blob"),
|
||||
columnBytes: getFun("sqlite3_column_bytes"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
interrupt: getFun("sqlite3_interrupt"),
|
||||
},
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type sqliteAPI struct {
|
||||
malloc api.Function
|
||||
free api.Function
|
||||
destructor uint64
|
||||
errcode api.Function
|
||||
errstr api.Function
|
||||
errmsg api.Function
|
||||
erroff api.Function
|
||||
open api.Function
|
||||
close api.Function
|
||||
prepare api.Function
|
||||
finalize api.Function
|
||||
reset api.Function
|
||||
step api.Function
|
||||
exec api.Function
|
||||
clearBindings api.Function
|
||||
bindNull api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
columnInteger api.Function
|
||||
columnFloat api.Function
|
||||
columnText api.Function
|
||||
columnBlob api.Function
|
||||
columnBytes api.Function
|
||||
lastRowid api.Function
|
||||
changes api.Function
|
||||
interrupt api.Function
|
||||
}
|
||||
134
backup.go
Normal file
134
backup.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package sqlite3
|
||||
|
||||
// Backup is an handle to an ongoing online backup operation.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup.html
|
||||
type Backup struct {
|
||||
c *Conn
|
||||
handle uint32
|
||||
otherc uint32
|
||||
}
|
||||
|
||||
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
|
||||
//
|
||||
// Backup opens the SQLite database file dstURI,
|
||||
// and blocks until the entire backup is complete.
|
||||
// Use [Conn.BackupInit] for incremental backup.
|
||||
//
|
||||
// https://www.sqlite.org/backup.html
|
||||
func (src *Conn) Backup(srcDB, dstURI string) error {
|
||||
b, err := src.BackupInit(srcDB, dstURI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer b.Close()
|
||||
_, err = b.Step(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
|
||||
//
|
||||
// Restore opens the SQLite database file srcURI,
|
||||
// and blocks until the entire restore is complete.
|
||||
//
|
||||
// https://www.sqlite.org/backup.html
|
||||
func (dst *Conn) Restore(dstDB, srcURI string) error {
|
||||
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := dst.backupInit(dst.handle, dstDB, src, "main")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer b.Close()
|
||||
_, err = b.Step(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
// BackupInit initializes a backup operation to copy the content of one database into another.
|
||||
//
|
||||
// BackupInit opens the SQLite database file dstURI,
|
||||
// then initializes a backup that copies the contents of srcDB on the src connection
|
||||
// to the "main" database in dstURI.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupinit
|
||||
func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
|
||||
dst, err := src.openDB(dstURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return src.backupInit(dst, "main", src.handle, srcDB)
|
||||
}
|
||||
|
||||
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
|
||||
defer c.arena.reset()
|
||||
dstPtr := c.arena.string(dstName)
|
||||
srcPtr := c.arena.string(srcName)
|
||||
|
||||
other := dst
|
||||
if c.handle == dst {
|
||||
other = src
|
||||
}
|
||||
|
||||
r := c.call(c.api.backupInit,
|
||||
uint64(dst), uint64(dstPtr),
|
||||
uint64(src), uint64(srcPtr))
|
||||
if r == 0 {
|
||||
defer c.closeDB(other)
|
||||
r = c.call(c.api.errcode, uint64(dst))
|
||||
return nil, c.sqlite.error(r, dst)
|
||||
}
|
||||
|
||||
return &Backup{
|
||||
c: c,
|
||||
otherc: other,
|
||||
handle: uint32(r),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close finishes a backup operation.
|
||||
//
|
||||
// It is safe to close a nil, zero or closed Backup.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupfinish
|
||||
func (b *Backup) Close() error {
|
||||
if b == nil || b.handle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
|
||||
b.c.closeDB(b.otherc)
|
||||
b.handle = 0
|
||||
return b.c.error(r)
|
||||
}
|
||||
|
||||
// Step copies up to nPage pages between the source and destination databases.
|
||||
// If nPage is negative, all remaining source pages are copied.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
|
||||
func (b *Backup) Step(nPage int) (done bool, err error) {
|
||||
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
|
||||
if r == _DONE {
|
||||
return true, nil
|
||||
}
|
||||
return false, b.c.error(r)
|
||||
}
|
||||
|
||||
// Remaining returns the number of pages still to be backed up
|
||||
// at the conclusion of the most recent [Backup.Step].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
|
||||
func (b *Backup) Remaining() int {
|
||||
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// PageCount returns the total number of pages in the source database
|
||||
// at the conclusion of the most recent [Backup.Step].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
|
||||
func (b *Backup) PageCount() int {
|
||||
r := b.c.call(b.c.api.backupPageCount, uint64(b.handle))
|
||||
return int(r)
|
||||
}
|
||||
247
blob.go
Normal file
247
blob.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// ZeroBlob represents a zero-filled, length n BLOB
|
||||
// that can be used as an argument to
|
||||
// [database/sql.DB.Exec] and similar methods.
|
||||
type ZeroBlob int64
|
||||
|
||||
// Blob is an handle to an open BLOB.
|
||||
//
|
||||
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob.html
|
||||
type Blob struct {
|
||||
c *Conn
|
||||
bytes int64
|
||||
offset int64
|
||||
handle uint32
|
||||
}
|
||||
|
||||
var _ io.ReadWriteSeeker = &Blob{}
|
||||
|
||||
// OpenBlob opens a BLOB for incremental I/O.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_open.html
|
||||
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
|
||||
c.checkInterrupt()
|
||||
defer c.arena.reset()
|
||||
blobPtr := c.arena.new(ptrlen)
|
||||
dbPtr := c.arena.string(db)
|
||||
tablePtr := c.arena.string(table)
|
||||
columnPtr := c.arena.string(column)
|
||||
|
||||
var flags uint64
|
||||
if write {
|
||||
flags = 1
|
||||
}
|
||||
|
||||
r := c.call(c.api.blobOpen, uint64(c.handle),
|
||||
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
|
||||
uint64(row), flags, uint64(blobPtr))
|
||||
|
||||
if err := c.error(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob := Blob{c: c}
|
||||
blob.handle = util.ReadUint32(c.mod, blobPtr)
|
||||
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle)))
|
||||
return &blob, nil
|
||||
}
|
||||
|
||||
// Close closes a BLOB handle.
|
||||
//
|
||||
// It is safe to close a nil, zero or closed Blob.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_close.html
|
||||
func (b *Blob) Close() error {
|
||||
if b == nil || b.handle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
|
||||
|
||||
b.handle = 0
|
||||
return b.c.error(r)
|
||||
}
|
||||
|
||||
// Size returns the size of the BLOB in bytes.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_bytes.html
|
||||
func (b *Blob) Size() int64 {
|
||||
return b.bytes
|
||||
}
|
||||
|
||||
// Read implements the [io.Reader] interface.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_read.html
|
||||
func (b *Blob) Read(p []byte) (n int, err error) {
|
||||
if b.offset >= b.bytes {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(len(p))
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
|
||||
defer b.c.arena.reset()
|
||||
ptr := b.c.arena.new(uint64(want))
|
||||
|
||||
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
|
||||
uint64(ptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b.offset += want
|
||||
if b.offset >= b.bytes {
|
||||
err = io.EOF
|
||||
}
|
||||
|
||||
copy(p, util.View(b.c.mod, ptr, uint64(want)))
|
||||
return int(want), err
|
||||
}
|
||||
|
||||
// WriteTo implements the [io.WriterTo] interface.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_read.html
|
||||
func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if b.offset >= b.bytes {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(65536)
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
|
||||
ptr := b.c.new(uint64(want))
|
||||
defer b.c.free(ptr)
|
||||
|
||||
for want > 0 {
|
||||
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
|
||||
uint64(ptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
mem := util.View(b.c.mod, ptr, uint64(want))
|
||||
m, err := w.Write(mem[:want])
|
||||
b.offset += int64(m)
|
||||
n += int64(m)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if int64(m) != want {
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
|
||||
avail = b.bytes - b.offset
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write implements the [io.Writer] interface.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_write.html
|
||||
func (b *Blob) Write(p []byte) (n int, err error) {
|
||||
defer b.c.arena.reset()
|
||||
ptr := b.c.arena.bytes(p)
|
||||
|
||||
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
|
||||
uint64(ptr), uint64(len(p)), uint64(b.offset))
|
||||
err = b.c.error(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b.offset += int64(len(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// ReadFrom implements the [io.ReaderFrom] interface.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_write.html
|
||||
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(65536)
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
if want < 1 {
|
||||
want = 1
|
||||
}
|
||||
|
||||
ptr := b.c.new(uint64(want))
|
||||
defer b.c.free(ptr)
|
||||
|
||||
for {
|
||||
mem := util.View(b.c.mod, ptr, uint64(want))
|
||||
m, err := r.Read(mem[:want])
|
||||
if m > 0 {
|
||||
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
|
||||
uint64(ptr), uint64(m), uint64(b.offset))
|
||||
err := b.c.error(r)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
b.offset += int64(m)
|
||||
n += int64(m)
|
||||
}
|
||||
if err == io.EOF {
|
||||
return n, nil
|
||||
}
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
avail = b.bytes - b.offset
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
if want < 1 {
|
||||
want = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Seek implements the [io.Seeker] interface.
|
||||
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
|
||||
switch whence {
|
||||
default:
|
||||
return 0, util.WhenceErr
|
||||
case io.SeekStart:
|
||||
break
|
||||
case io.SeekCurrent:
|
||||
offset += b.offset
|
||||
case io.SeekEnd:
|
||||
offset += b.bytes
|
||||
}
|
||||
if offset < 0 {
|
||||
return 0, util.OffsetErr
|
||||
}
|
||||
b.offset = offset
|
||||
return offset, nil
|
||||
}
|
||||
|
||||
// Reopen moves a BLOB handle to a new row of the same database table.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_reopen.html
|
||||
func (b *Blob) Reopen(row int64) error {
|
||||
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row)))
|
||||
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle)))
|
||||
b.offset = 0
|
||||
return err
|
||||
}
|
||||
54
compile.go
54
compile.go
@@ -1,54 +0,0 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// Configure SQLite.
|
||||
var (
|
||||
Binary []byte // Binary to load.
|
||||
Path string // Path to load the binary from.
|
||||
)
|
||||
|
||||
var sqlite3 sqlite3Runtime
|
||||
|
||||
type sqlite3Runtime struct {
|
||||
once sync.Once
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
instances atomic.Uint64
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
|
||||
s.once.Do(func() { s.compileModule(ctx) })
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
|
||||
cfg := wazero.NewModuleConfig().
|
||||
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10))
|
||||
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
|
||||
}
|
||||
|
||||
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
|
||||
s.runtime = wazero.NewRuntime(ctx)
|
||||
vfsInstantiate(ctx, s.runtime)
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
bin, s.err = os.ReadFile(Path)
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
451
conn.go
451
conn.go
@@ -2,64 +2,119 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Conn is a database connection handle.
|
||||
// A Conn is not safe for concurrent use by multiple goroutines.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/sqlite3.html
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
api sqliteAPI
|
||||
mem memory
|
||||
handle uint32
|
||||
*sqlite
|
||||
|
||||
arena arena
|
||||
pending *Stmt
|
||||
waiter chan struct{}
|
||||
done <-chan struct{}
|
||||
interrupt context.Context
|
||||
waiter chan struct{}
|
||||
pending *Stmt
|
||||
arena arena
|
||||
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
|
||||
func Open(filename string) (conn *Conn, err error) {
|
||||
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE)
|
||||
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
|
||||
func Open(filename string) (*Conn, error) {
|
||||
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
|
||||
}
|
||||
|
||||
// OpenFlags opens an SQLite database file as specified by the filename argument.
|
||||
//
|
||||
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
|
||||
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
|
||||
//
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/open.html
|
||||
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
ctx := context.Background()
|
||||
module, err := sqlite3.instantiateModule(ctx)
|
||||
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
|
||||
flags |= OPEN_READWRITE | OPEN_CREATE
|
||||
}
|
||||
return newConn(filename, flags)
|
||||
}
|
||||
|
||||
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if conn == nil {
|
||||
module.Close(ctx)
|
||||
sqlite.close()
|
||||
} else {
|
||||
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := newConn(ctx, module)
|
||||
c := &Conn{sqlite: sqlite}
|
||||
c.arena = c.newArena(1024)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.arena = c.newArena(1024)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
defer c.arena.reset()
|
||||
connPtr := c.arena.new(ptrlen)
|
||||
namePtr := c.arena.string(filename)
|
||||
|
||||
r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
flags |= OPEN_EXRESCODE
|
||||
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
|
||||
handle := util.ReadUint32(c.mod, connPtr)
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.handle = c.mem.readUint32(connPtr)
|
||||
if err := c.error(r[0]); err != nil {
|
||||
return nil, err
|
||||
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
|
||||
var pragmas strings.Builder
|
||||
if _, after, ok := strings.Cut(filename, "?"); ok {
|
||||
query, _ := url.ParseQuery(after)
|
||||
for _, p := range query["_pragma"] {
|
||||
pragmas.WriteString(`PRAGMA `)
|
||||
pragmas.WriteString(p)
|
||||
pragmas.WriteByte(';')
|
||||
}
|
||||
}
|
||||
|
||||
c.arena.reset()
|
||||
pragmaPtr := c.arena.string(pragmas.String())
|
||||
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
|
||||
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
|
||||
if errors.Is(err, ERROR) {
|
||||
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
|
||||
}
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func (c *Conn) closeDB(handle uint32) {
|
||||
r := c.call(c.api.closeZombie, uint64(handle))
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
@@ -68,88 +123,26 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
// open blob handles, and/or unfinished backup objects,
|
||||
// Close will leave the database connection open and return [BUSY].
|
||||
//
|
||||
// It is safe to close a nil, zero or closed Conn.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/close.html
|
||||
func (c *Conn) Close() error {
|
||||
if c == nil || c.handle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.SetInterrupt(nil)
|
||||
c.SetInterrupt(context.Background())
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
|
||||
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := c.error(r[0]); err != nil {
|
||||
r := c.call(c.api.close, uint64(c.handle))
|
||||
if err := c.error(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.handle = 0
|
||||
return c.mem.mod.Close(c.ctx)
|
||||
}
|
||||
|
||||
// SetInterrupt interrupts a long-running query when done is closed.
|
||||
//
|
||||
// Subsequent uses of the connection will return [INTERRUPT]
|
||||
// until done is reset by another call to SetInterrupt.
|
||||
//
|
||||
// Typically, done is provided by [context.Context.Done]:
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
|
||||
// conn.SetInterrupt(ctx.Done())
|
||||
// defer cancel()
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
<-c.waiter // Wait for it to finish.
|
||||
c.waiter = nil
|
||||
}
|
||||
|
||||
// Finalize the uncompleted SQL statement.
|
||||
if c.pending != nil {
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
}
|
||||
|
||||
old = c.done
|
||||
c.done = done
|
||||
if done == nil {
|
||||
return old
|
||||
}
|
||||
|
||||
// Creating an uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
c.pending.Step()
|
||||
|
||||
waiter := make(chan struct{})
|
||||
c.waiter = waiter
|
||||
go func() {
|
||||
select {
|
||||
case <-waiter: // Waiter was cancelled.
|
||||
break
|
||||
|
||||
case <-done: // Done was closed.
|
||||
|
||||
// This is safe to call from a goroutine
|
||||
// because it doesn't touch the C stack.
|
||||
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
}
|
||||
|
||||
// Signal that the waiter has finished.
|
||||
waiter <- struct{}{}
|
||||
}()
|
||||
return old
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return c.close()
|
||||
}
|
||||
|
||||
// Exec is a convenience function that allows an application to run
|
||||
@@ -157,14 +150,12 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/exec.html
|
||||
func (c *Conn) Exec(sql string) error {
|
||||
c.checkInterrupt()
|
||||
defer c.arena.reset()
|
||||
sqlPtr := c.arena.string(sql)
|
||||
|
||||
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c.error(r[0])
|
||||
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// Prepare calls [Conn.PrepareFlags] with no flags.
|
||||
@@ -179,24 +170,25 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/prepare.html
|
||||
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
|
||||
if emptyStatement(sql) {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
defer c.arena.reset()
|
||||
stmtPtr := c.arena.new(ptrlen)
|
||||
tailPtr := c.arena.new(ptrlen)
|
||||
sqlPtr := c.arena.string(sql)
|
||||
|
||||
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
|
||||
r := c.call(c.api.prepare, uint64(c.handle),
|
||||
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
|
||||
uint64(stmtPtr), uint64(tailPtr))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
stmt = &Stmt{c: c}
|
||||
stmt.handle = c.mem.readUint32(stmtPtr)
|
||||
i := c.mem.readUint32(tailPtr)
|
||||
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
|
||||
i := util.ReadUint32(c.mod, tailPtr)
|
||||
tail = sql[i-sqlPtr:]
|
||||
|
||||
if err := c.error(r[0], sql); err != nil {
|
||||
if err := c.error(r, sql); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if stmt.handle == 0 {
|
||||
@@ -205,16 +197,21 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
return
|
||||
}
|
||||
|
||||
// GetAutocommit tests the connection for auto-commit mode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_autocommit.html
|
||||
func (c *Conn) GetAutocommit() bool {
|
||||
r := c.call(c.api.autocommit, uint64(c.handle))
|
||||
return r != 0
|
||||
}
|
||||
|
||||
// LastInsertRowID returns the rowid of the most recent successful INSERT
|
||||
// on the database connection.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/last_insert_rowid.html
|
||||
func (c *Conn) LastInsertRowID() uint64 {
|
||||
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r[0]
|
||||
func (c *Conn) LastInsertRowID() int64 {
|
||||
r := c.call(c.api.lastRowid, uint64(c.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// Changes returns the number of rows modified, inserted or deleted
|
||||
@@ -222,125 +219,121 @@ func (c *Conn) LastInsertRowID() uint64 {
|
||||
// on the database connection.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/changes.html
|
||||
func (c *Conn) Changes() uint64 {
|
||||
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
func (c *Conn) Changes() int64 {
|
||||
r := c.call(c.api.changes, uint64(c.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// SetInterrupt interrupts a long-running query when a context is done.
|
||||
//
|
||||
// Subsequent uses of the connection will return [INTERRUPT]
|
||||
// until the context is reset by another call to SetInterrupt.
|
||||
//
|
||||
// To associate a timeout with a connection:
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
|
||||
// conn.SetInterrupt(ctx)
|
||||
// defer cancel()
|
||||
//
|
||||
// SetInterrupt returns the old context assigned to the connection.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
// Is it the same context?
|
||||
if ctx == c.interrupt {
|
||||
return ctx
|
||||
}
|
||||
return r[0]
|
||||
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
<-c.waiter // Wait for it to finish.
|
||||
c.waiter = nil
|
||||
}
|
||||
// Reset the pending statement.
|
||||
if c.pending != nil {
|
||||
c.pending.Reset()
|
||||
}
|
||||
|
||||
old = c.interrupt
|
||||
c.interrupt = ctx
|
||||
if ctx == nil || ctx.Done() == nil {
|
||||
return old
|
||||
}
|
||||
|
||||
// Creating an uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
}
|
||||
c.pending.Step()
|
||||
|
||||
// Don't create the goroutine if we're already interrupted.
|
||||
// This happens frequently while restoring to a previously interrupted state.
|
||||
if c.checkInterrupt() {
|
||||
return old
|
||||
}
|
||||
|
||||
waiter := make(chan struct{})
|
||||
c.waiter = waiter
|
||||
go func() {
|
||||
select {
|
||||
case <-waiter: // Waiter was cancelled.
|
||||
break
|
||||
|
||||
case <-ctx.Done(): // Done was closed.
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
}
|
||||
|
||||
// Signal that the waiter has finished.
|
||||
waiter <- struct{}{}
|
||||
}()
|
||||
return old
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() bool {
|
||||
if c.interrupt == nil || c.interrupt.Err() == nil {
|
||||
return false
|
||||
}
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pragma executes a PRAGMA statement and returns any results.
|
||||
//
|
||||
// https://www.sqlite.org/pragma.html
|
||||
func (c *Conn) Pragma(str string) ([]string, error) {
|
||||
stmt, _, err := c.Prepare(`PRAGMA ` + str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
var pragmas []string
|
||||
for stmt.Step() {
|
||||
pragmas = append(pragmas, stmt.ColumnText(0))
|
||||
}
|
||||
return pragmas, stmt.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := Error{code: rc}
|
||||
|
||||
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
|
||||
panic(oomErr)
|
||||
}
|
||||
|
||||
var r []uint64
|
||||
|
||||
r, _ = c.api.errstr.Call(c.ctx, rc)
|
||||
if r != nil {
|
||||
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
|
||||
}
|
||||
|
||||
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
|
||||
if r != nil {
|
||||
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
|
||||
if r != nil && r[0] != math.MaxUint32 {
|
||||
err.sql = sql[0][r[0]:]
|
||||
}
|
||||
}
|
||||
|
||||
switch err.msg {
|
||||
case err.str, "not an error":
|
||||
err.msg = ""
|
||||
}
|
||||
return &err
|
||||
return c.sqlite.error(rc, c.handle, sql...)
|
||||
}
|
||||
|
||||
func (c *Conn) free(ptr uint32) {
|
||||
if ptr == 0 {
|
||||
return
|
||||
}
|
||||
_, err := c.api.free.Call(c.ctx, uint64(ptr))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) new(size uint32) uint32 {
|
||||
r, err := c.api.malloc.Call(c.ctx, uint64(size))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(oomErr)
|
||||
}
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (c *Conn) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := c.new(uint32(len(b)))
|
||||
c.mem.writeBytes(ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (c *Conn) newString(s string) uint32 {
|
||||
ptr := c.new(uint32(len(s) + 1))
|
||||
c.mem.writeString(ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (c *Conn) newArena(size uint32) arena {
|
||||
return arena{
|
||||
c: c,
|
||||
size: size,
|
||||
base: c.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
c *Conn
|
||||
base uint32
|
||||
next uint32
|
||||
size uint32
|
||||
ptrs []uint32
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.c.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
}
|
||||
|
||||
func (a *arena) new(size uint32) uint32 {
|
||||
if a.next+size <= a.size {
|
||||
ptr := a.base + a.next
|
||||
a.next += size
|
||||
return ptr
|
||||
}
|
||||
ptr := a.c.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint32(len(s) + 1))
|
||||
a.c.mem.writeString(ptr, s)
|
||||
return ptr
|
||||
// DriverConn is implemented by the SQLite [database/sql] driver connection.
|
||||
//
|
||||
// It can be used to access advanced SQLite features like
|
||||
// [savepoints], [online backup] and [incremental BLOB I/O].
|
||||
//
|
||||
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
|
||||
// [online backup]: https://www.sqlite.org/backup.html
|
||||
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
|
||||
type DriverConn interface {
|
||||
Raw() *Conn
|
||||
}
|
||||
|
||||
342
conn_test.go
342
conn_test.go
@@ -1,342 +0,0 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
var conn *Conn
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestConn_Close_BUSY(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`BEGIN`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != BUSY {
|
||||
t.Errorf("got %d, want sqlite3.BUSY", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_SetInterrupt(t *testing.T) {
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
// Interrupt doesn't interrupt this.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
WITH RECURSIVE
|
||||
fibonacci (curr, next)
|
||||
AS (
|
||||
SELECT 0, 1
|
||||
UNION ALL
|
||||
SELECT next, curr + next FROM fibonacci
|
||||
LIMIT 1e6
|
||||
)
|
||||
SELECT min(curr) FROM fibonacci
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
cancel()
|
||||
db.SetInterrupt(ctx.Done())
|
||||
|
||||
var serr *Error
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Interrupting sticks.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != INTERRUPT {
|
||||
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: interrupted` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
db.SetInterrupt(nil)
|
||||
|
||||
// Interrupting can be cleared.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(``)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt != nil {
|
||||
t.Error("want nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var serr *Error
|
||||
|
||||
_, _, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
|
||||
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.ERROR", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := serr.SQL(); got != `FRM sqlite_schema` {
|
||||
t.Error("got SQL: ", got)
|
||||
}
|
||||
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
db.new(math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_newArena(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
arena := db.newArena(16)
|
||||
defer arena.reset()
|
||||
|
||||
const title = "Lorem ipsum"
|
||||
|
||||
ptr := arena.string(title)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
const body = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
|
||||
ptr = arena.string(body)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
|
||||
t.Errorf("got %q, want %q", got, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_newBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ptr := db.newBytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("got %#x, want nullptr", ptr)
|
||||
}
|
||||
|
||||
buf := []byte("sqlite3")
|
||||
ptr = db.newBytes(buf)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := buf
|
||||
if got := db.mem.view(ptr, uint32(len(want))); !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_newString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ptr := db.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3\000sqlite3"
|
||||
ptr = db.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := str + "\000"
|
||||
if got := db.mem.view(ptr, uint32(len(want))); string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_getString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ptr := db.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3" + "\000 drop this"
|
||||
ptr = db.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := "sqlite3"
|
||||
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if got := db.mem.readString(ptr, 0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
db.mem.readString(ptr, uint32(len(want)/2))
|
||||
t.Error("want panic")
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
db.mem.readString(0, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}()
|
||||
}
|
||||
|
||||
func TestConn_free(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
db.free(0)
|
||||
|
||||
ptr := db.new(1)
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
db.free(ptr)
|
||||
}
|
||||
70
const.go
70
const.go
@@ -9,8 +9,9 @@ const (
|
||||
|
||||
_UTF8 = 1
|
||||
|
||||
_MAX_STRING = 512 // Used for short strings: names, error messages…
|
||||
_MAX_PATHNAME = 512
|
||||
_MAX_STRING = 512 // Used for short strings: names, error messages…
|
||||
|
||||
_MAX_ALLOCATION_SIZE = 0x7ffffeff
|
||||
|
||||
ptrlen = 4
|
||||
)
|
||||
@@ -131,46 +132,28 @@ const (
|
||||
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
|
||||
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
|
||||
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
|
||||
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
|
||||
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
|
||||
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
|
||||
)
|
||||
|
||||
// OpenFlag is a flag for a file open operation.
|
||||
// OpenFlag is a flag for the [OpenFlags] function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
|
||||
type OpenFlag uint32
|
||||
|
||||
const (
|
||||
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
|
||||
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
|
||||
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
|
||||
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
|
||||
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
|
||||
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
|
||||
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
|
||||
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
|
||||
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
|
||||
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
|
||||
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
|
||||
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
|
||||
)
|
||||
|
||||
type _AccessFlag uint32
|
||||
|
||||
const (
|
||||
_ACCESS_EXISTS _AccessFlag = 0
|
||||
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
|
||||
_ACCESS_READ _AccessFlag = 2 /* Unused */
|
||||
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
|
||||
)
|
||||
|
||||
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
@@ -184,6 +167,18 @@ const (
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
)
|
||||
|
||||
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_deterministic.html
|
||||
type FunctionFlag uint32
|
||||
|
||||
const (
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
SUBTYPE FunctionFlag = 0x000100000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
)
|
||||
|
||||
// Datatype is a fundamental datatype of SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_blob.html
|
||||
@@ -197,19 +192,20 @@ const (
|
||||
NULL Datatype = 5
|
||||
)
|
||||
|
||||
// String implements the [fmt.Stringer] interface.
|
||||
func (t Datatype) String() string {
|
||||
const name = "INTEGERFLOATTEXTBLOBNULL"
|
||||
const name = "INTEGERFLOATEXTBLOBNULL"
|
||||
switch t {
|
||||
case INTEGER:
|
||||
return name[0:7]
|
||||
case FLOAT:
|
||||
return name[7:12]
|
||||
case TEXT:
|
||||
return name[12:16]
|
||||
return name[11:15]
|
||||
case BLOB:
|
||||
return name[16:20]
|
||||
return name[15:19]
|
||||
case NULL:
|
||||
return name[20:24]
|
||||
return name[19:23]
|
||||
}
|
||||
return strconv.FormatUint(uint64(t), 10)
|
||||
}
|
||||
|
||||
174
context.go
Normal file
174
context.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Context is the context in which an SQL function executes.
|
||||
// An SQLite [Context] is in no way related to a Go [context.Context].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// SetAuxData saves metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(c.ctx, data)
|
||||
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
|
||||
}
|
||||
|
||||
// GetAuxData returns metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) GetAuxData(n int) any {
|
||||
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
|
||||
return util.GetHandle(c.ctx, ptr)
|
||||
}
|
||||
|
||||
// ResultBool sets the result of the function to a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBool(value bool) {
|
||||
var i int64
|
||||
if value {
|
||||
i = 1
|
||||
}
|
||||
c.ResultInt64(i)
|
||||
}
|
||||
|
||||
// ResultInt sets the result of the function to an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt(value int) {
|
||||
c.ResultInt64(int64(value))
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of the function to an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt64(value int64) {
|
||||
c.call(c.api.resultInteger,
|
||||
uint64(c.handle), uint64(value))
|
||||
}
|
||||
|
||||
// ResultFloat sets the result of the function to a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultFloat(value float64) {
|
||||
c.call(c.api.resultFloat,
|
||||
uint64(c.handle), math.Float64bits(value))
|
||||
}
|
||||
|
||||
// ResultText sets the result of the function to a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultText(value string) {
|
||||
ptr := c.newString(value)
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBlob(value []byte) {
|
||||
ptr := c.newBytes(value)
|
||||
c.call(c.api.resultBlob,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor))
|
||||
}
|
||||
|
||||
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultZeroBlob(n int64) {
|
||||
c.call(c.api.resultZeroBlob,
|
||||
uint64(c.handle), uint64(n))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of the function to NULL.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultNull() {
|
||||
c.call(c.api.resultNull,
|
||||
uint64(c.handle))
|
||||
}
|
||||
|
||||
// ResultTime sets the result of the function to a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
if format == TimeFormatDefault {
|
||||
c.resultRFC3339Nano(value)
|
||||
return
|
||||
}
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
c.ResultText(v)
|
||||
case int64:
|
||||
c.ResultInt64(v)
|
||||
case float64:
|
||||
c.ResultFloat(v)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (c Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
|
||||
ptr := c.new(maxlen)
|
||||
buf := util.View(c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
c.call(c.api.resultErrorMem, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
c.call(c.api.resultErrorBig, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
str := err.Error()
|
||||
ptr := c.newString(str)
|
||||
c.call(c.api.resultError,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(str)))
|
||||
c.free(ptr)
|
||||
|
||||
var code uint64
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
switch {
|
||||
case errors.As(err, &xcode):
|
||||
code = uint64(xcode)
|
||||
case errors.As(err, &ecode):
|
||||
code = uint64(ecode)
|
||||
}
|
||||
if code != 0 {
|
||||
c.call(c.api.resultErrorCode,
|
||||
uint64(c.handle), code)
|
||||
}
|
||||
}
|
||||
418
driver/driver.go
418
driver/driver.go
@@ -1,181 +1,352 @@
|
||||
// Package driver provides a database/sql driver for SQLite.
|
||||
//
|
||||
// Importing package driver registers a [database/sql] driver named "sqlite3".
|
||||
// You may also need to import package embed.
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/driver"
|
||||
// import _ "github.com/ncruces/go-sqlite3/embed"
|
||||
//
|
||||
// The data source name for "sqlite3" databases can be a filename or a "file:" [URI].
|
||||
//
|
||||
// The [TRANSACTION] mode can be specified using "_txlock":
|
||||
//
|
||||
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
|
||||
//
|
||||
// [PRAGMA] statements can be specified using "_pragma":
|
||||
//
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
|
||||
//
|
||||
// Order matters:
|
||||
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
|
||||
//
|
||||
// [URI]: https://www.sqlite.org/uri.html
|
||||
// [PRAGMA]: https://www.sqlite.org/pragma.html
|
||||
// [TRANSACTION]: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// This variable can be replaced with -ldflags:
|
||||
//
|
||||
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
|
||||
var driverName = "sqlite3"
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", sqlite{})
|
||||
if driverName != "" {
|
||||
sql.Register(driverName, sqlite{})
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
|
||||
//
|
||||
// The init function is called by the driver on new connections.
|
||||
// The conn can be used to execute queries, register functions, etc.
|
||||
// Any error return closes the conn and passes the error to database/sql.
|
||||
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
|
||||
c, err := newConnector(dataSourceName, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.OpenDB(c), nil
|
||||
}
|
||||
|
||||
type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
|
||||
c, err := newConnector(name, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If the database is not in WAL mode,
|
||||
// use normal locking mode.
|
||||
journal, err := pragma(c, "journal_mode")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if journal != "wal" {
|
||||
pragma(c, "locking_mode=normal")
|
||||
}
|
||||
return conn{c}, nil
|
||||
return c.Connect(context.Background())
|
||||
}
|
||||
|
||||
type conn struct{ conn *sqlite3.Conn }
|
||||
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
|
||||
return newConnector(name, nil)
|
||||
}
|
||||
|
||||
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
|
||||
c := connector{name: name, init: init}
|
||||
if strings.HasPrefix(name, "file:") {
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, err := url.ParseQuery(after)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.txlock = query.Get("_txlock")
|
||||
c.pragmas = len(query["_pragma"]) > 0
|
||||
}
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type connector struct {
|
||||
init func(ctx context.Context, conn *sqlite3.Conn) error
|
||||
name string
|
||||
txlock string
|
||||
pragmas bool
|
||||
}
|
||||
|
||||
func (n *connector) Driver() driver.Driver {
|
||||
return sqlite{}
|
||||
}
|
||||
|
||||
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.Conn, err = sqlite3.Open(n.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
switch n.txlock {
|
||||
case "":
|
||||
c.txBegin = "BEGIN"
|
||||
case "deferred", "immediate", "exclusive":
|
||||
c.txBegin = "BEGIN " + n.txlock
|
||||
default:
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
|
||||
}
|
||||
if !n.pragmas {
|
||||
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.reusable = true
|
||||
} else {
|
||||
s, _, err := c.Conn.Prepare(`
|
||||
SELECT * FROM
|
||||
PRAGMA_locking_mode,
|
||||
PRAGMA_query_only;
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.Step() {
|
||||
c.reusable = s.ColumnText(0) == "normal"
|
||||
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
|
||||
}
|
||||
err = s.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if n.init != nil {
|
||||
err = n.init(ctx, c.Conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
*sqlite3.Conn
|
||||
txBegin string
|
||||
txCommit string
|
||||
txRollback string
|
||||
reusable bool
|
||||
readOnly byte
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.Validator = conn{}
|
||||
_ driver.ExecerContext = conn{}
|
||||
// _ driver.ConnBeginTx = conn{}
|
||||
// _ driver.SessionResetter = conn{}
|
||||
_ driver.ConnPrepareContext = &conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
func (c conn) Close() error {
|
||||
return c.conn.Close()
|
||||
func (c *conn) Raw() *sqlite3.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c conn) IsValid() bool {
|
||||
// Pool only normal locking mode connections.
|
||||
mode, _ := pragma(c.conn, "locking_mode")
|
||||
return mode == "normal"
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
err := c.conn.Exec(`BEGIN`)
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
txBegin := c.txBegin
|
||||
c.txCommit = `COMMIT`
|
||||
c.txRollback = `ROLLBACK`
|
||||
|
||||
if opts.ReadOnly {
|
||||
txBegin = `
|
||||
BEGIN deferred;
|
||||
PRAGMA query_only=on`
|
||||
c.txCommit = `
|
||||
ROLLBACK;
|
||||
PRAGMA query_only=` + string(c.readOnly)
|
||||
c.txRollback = c.txCommit
|
||||
}
|
||||
|
||||
switch opts.Isolation {
|
||||
default:
|
||||
return nil, util.IsolationErr
|
||||
case
|
||||
driver.IsolationLevel(sql.LevelDefault),
|
||||
driver.IsolationLevel(sql.LevelSerializable):
|
||||
break
|
||||
}
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
err := c.Conn.Exec(txBegin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c conn) Commit() error {
|
||||
err := c.conn.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
func (c *conn) Commit() error {
|
||||
err := c.Conn.Exec(c.txCommit)
|
||||
if err != nil && !c.Conn.GetAutocommit() {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c conn) Rollback() error {
|
||||
return c.conn.Exec(`ROLLBACK`)
|
||||
func (c *conn) Rollback() error {
|
||||
return c.Conn.Exec(c.txRollback)
|
||||
}
|
||||
|
||||
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
s, tail, err := c.conn.Prepare(query)
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
s, tail, err := c.Conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tail != "" {
|
||||
// Check if the tail contains any SQL.
|
||||
s, _, err := c.conn.Prepare(tail)
|
||||
st, _, err := c.Conn.Prepare(tail)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
if s != nil {
|
||||
if st != nil {
|
||||
s.Close()
|
||||
return nil, tailErr
|
||||
st.Close()
|
||||
return nil, util.TailErr
|
||||
}
|
||||
}
|
||||
return stmt{s, c.conn}, nil
|
||||
return &stmt{s, c.Conn}, nil
|
||||
}
|
||||
|
||||
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if len(args) != 0 {
|
||||
// Slow path.
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
ch := c.conn.SetInterrupt(ctx.Done())
|
||||
defer c.conn.SetInterrupt(ch)
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
err := c.conn.Exec(query)
|
||||
err := c.Conn.Exec(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
int64(c.conn.LastInsertRowID()),
|
||||
int64(c.conn.Changes()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pragma(c *sqlite3.Conn, pragma string) (string, error) {
|
||||
stmt, _, err := c.Prepare(`PRAGMA ` + pragma)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0), nil
|
||||
}
|
||||
return "", stmt.Err()
|
||||
return newResult(c.Conn), nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
_ driver.StmtExecContext = &stmt{}
|
||||
_ driver.StmtQueryContext = &stmt{}
|
||||
_ driver.NamedValueChecker = &stmt{}
|
||||
)
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
func (s *stmt) Close() error {
|
||||
return s.Stmt.Close()
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
return s.stmt.BindCount()
|
||||
func (s *stmt) NumInput() int {
|
||||
n := s.Stmt.BindCount()
|
||||
for i := 1; i <= n; i++ {
|
||||
if s.Stmt.BindName(i) != "" {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Deprecated: use ExecContext instead.
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
// Deprecated: use QueryContext instead.
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
err := s.setupBindings(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.stmt.Exec()
|
||||
old := s.Conn.SetInterrupt(ctx)
|
||||
defer s.Conn.SetInterrupt(old)
|
||||
|
||||
err = s.Stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
int64(s.conn.LastInsertRowID()),
|
||||
int64(s.conn.Changes()),
|
||||
}, nil
|
||||
return newResult(s.Conn), nil
|
||||
}
|
||||
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.stmt.ClearBindings()
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.setupBindings(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
}
|
||||
|
||||
func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
err := s.Stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ids [3]int
|
||||
for _, arg := range args {
|
||||
@@ -184,7 +355,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
ids = append(ids, arg.Ordinal)
|
||||
} else {
|
||||
for _, prefix := range []string{":", "@", "$"} {
|
||||
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
@@ -193,29 +364,53 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
for _, id := range ids {
|
||||
switch a := arg.Value.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(id, a)
|
||||
err = s.Stmt.BindBool(id, a)
|
||||
case int:
|
||||
err = s.Stmt.BindInt(id, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(id, a)
|
||||
err = s.Stmt.BindInt64(id, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(id, a)
|
||||
err = s.Stmt.BindFloat(id, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(id, a)
|
||||
err = s.Stmt.BindText(id, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(id, a)
|
||||
err = s.Stmt.BindBlob(id, a)
|
||||
case sqlite3.ZeroBlob:
|
||||
err = s.Stmt.BindZeroBlob(id, int64(a))
|
||||
case time.Time:
|
||||
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
|
||||
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
case nil:
|
||||
err = s.stmt.BindNull(id)
|
||||
err = s.Stmt.BindNull(id)
|
||||
default:
|
||||
panic(assertErr)
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return rows{ctx, s.stmt, s.conn}, nil
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, nil:
|
||||
return nil
|
||||
default:
|
||||
return driver.ErrSkip
|
||||
}
|
||||
}
|
||||
|
||||
func newResult(c *sqlite3.Conn) driver.Result {
|
||||
rows := c.Changes()
|
||||
if rows != 0 {
|
||||
id := c.LastInsertRowID()
|
||||
if id != 0 {
|
||||
return result{id, rows}
|
||||
}
|
||||
}
|
||||
return resultRowsAffected(rows)
|
||||
}
|
||||
|
||||
type result struct{ lastInsertId, rowsAffected int64 }
|
||||
@@ -228,47 +423,56 @@ func (r result) RowsAffected() (int64, error) {
|
||||
return r.rowsAffected, nil
|
||||
}
|
||||
|
||||
type resultRowsAffected int64
|
||||
|
||||
func (r resultRowsAffected) LastInsertId() (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r resultRowsAffected) RowsAffected() (int64, error) {
|
||||
return int64(r), nil
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
ctx context.Context
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
func (r rows) Close() error {
|
||||
return r.stmt.Reset()
|
||||
func (r *rows) Close() error {
|
||||
return r.Stmt.Reset()
|
||||
}
|
||||
|
||||
func (r rows) Columns() []string {
|
||||
count := r.stmt.ColumnCount()
|
||||
func (r *rows) Columns() []string {
|
||||
count := r.Stmt.ColumnCount()
|
||||
columns := make([]string, count)
|
||||
for i := range columns {
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
columns[i] = r.Stmt.ColumnName(i)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (r rows) Next(dest []driver.Value) error {
|
||||
ch := r.conn.SetInterrupt(r.ctx.Done())
|
||||
defer r.conn.SetInterrupt(ch)
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
old := r.Conn.SetInterrupt(r.ctx)
|
||||
defer r.Conn.SetInterrupt(old)
|
||||
|
||||
if !r.stmt.Step() {
|
||||
if err := r.stmt.Err(); err != nil {
|
||||
if !r.Stmt.Step() {
|
||||
if err := r.Stmt.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch r.stmt.ColumnType(i) {
|
||||
switch r.Stmt.ColumnType(i) {
|
||||
case sqlite3.INTEGER:
|
||||
dest[i] = r.stmt.ColumnInt64(i)
|
||||
dest[i] = r.Stmt.ColumnInt64(i)
|
||||
case sqlite3.FLOAT:
|
||||
dest[i] = r.stmt.ColumnFloat(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = maybeDate(r.stmt.ColumnText(i))
|
||||
dest[i] = r.Stmt.ColumnFloat(i)
|
||||
case sqlite3.BLOB:
|
||||
buf, _ := dest[i].([]byte)
|
||||
dest[i] = r.stmt.ColumnBlob(i, buf)
|
||||
dest[i] = r.Stmt.ColumnRawBlob(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
@@ -276,9 +480,9 @@ func (r rows) Next(dest []driver.Value) error {
|
||||
dest[i] = nil
|
||||
}
|
||||
default:
|
||||
panic(assertErr)
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
return r.stmt.Err()
|
||||
return r.Stmt.Err()
|
||||
}
|
||||
|
||||
311
driver/driver_test.go
Normal file
311
driver/driver_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
func Test_Open_dir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ".")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Conn(context.TODO())
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_pragma(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout(1000)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var timeout int
|
||||
err = db.QueryRow(`PRAGMA busy_timeout`).Scan(&timeout)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if timeout != 1000 {
|
||||
t.Errorf("got %v, want 1000", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_pragma_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", "file::memory:?_pragma=busy_timeout+1000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Conn(context.TODO())
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: invalid _pragma: sqlite3: SQL logic error: near "1000": syntax error` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_txLock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
|
||||
"?_txlock=exclusive&_pragma=busy_timeout(0)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx1, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Begin()
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.BUSY) {
|
||||
t.Errorf("got %v, want sqlite3.BUSY", err)
|
||||
}
|
||||
var terr interface{ Temporary() bool }
|
||||
if !errors.As(err, &terr) || !terr.Temporary() {
|
||||
t.Error("not temporary", err)
|
||||
}
|
||||
|
||||
err = tx1.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Open_txLock_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Conn(context.TODO())
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: invalid _txlock: xclusive` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BeginTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
|
||||
"?_txlock=exclusive&_pragma=busy_timeout(0)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
if err.Error() != string(util.IsolationErr) {
|
||||
t.Error("want isolationErr")
|
||||
}
|
||||
|
||||
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx2, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx1.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.READONLY) {
|
||||
t.Errorf("got %v, want sqlite3.READONLY", err)
|
||||
}
|
||||
|
||||
err = tx2.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = tx1.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, err := db.Prepare(`SELECT 1; -- HERE`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
var serr *sqlite3.Error
|
||||
_, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
|
||||
_, err = db.Prepare(`SELECT 1; SELECT`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
|
||||
_, err = db.Prepare(`SELECT 1; SELECT 2`)
|
||||
if err.Error() != string(util.TailErr) {
|
||||
t.Error("want tailErr")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_QueryRow_named(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
stmt, err := conn.PrepareContext(ctx, `SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
date := time.Now()
|
||||
row := stmt.QueryRow(true, sql.Named("AAA", math.Pi), nil /*3*/, nil /*4*/, date /*5*/)
|
||||
|
||||
var first bool
|
||||
var fifth time.Time
|
||||
var colon, at, dollar float32
|
||||
err = row.Scan(&first, &fifth, &colon, &at, &dollar)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if first != true {
|
||||
t.Errorf("want true, got %v", first)
|
||||
}
|
||||
if colon != math.Pi {
|
||||
t.Errorf("want π, got %v", colon)
|
||||
}
|
||||
if at != math.Pi {
|
||||
t.Errorf("want π, got %v", at)
|
||||
}
|
||||
if dollar != math.Pi {
|
||||
t.Errorf("want π, got %v", dollar)
|
||||
}
|
||||
if !fifth.Equal(date) {
|
||||
t.Errorf("want %v, got %v", date, fifth)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_QueryRow_blob_null(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT NULL UNION ALL
|
||||
SELECT x'cafe' UNION ALL
|
||||
SELECT x'babe' UNION ALL
|
||||
SELECT NULL
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var buf sql.RawBytes
|
||||
err = rows.Scan(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf, want[i]) {
|
||||
t.Errorf("got %q, want %q", buf, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package driver
|
||||
|
||||
type errorString string
|
||||
|
||||
func (e errorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
assertErr = errorString("sqlite3: assertion failed")
|
||||
tailErr = errorString("sqlite3: multiple statements")
|
||||
)
|
||||
@@ -28,10 +28,11 @@ func Example() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
defer os.Remove("./recordings.db")
|
||||
defer db.Close()
|
||||
|
||||
err = createAlbumsTable()
|
||||
// Create a table with some data in it.
|
||||
err = albumsSetup()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -58,14 +59,13 @@ func Example() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("ID of added album: %v\n", albID)
|
||||
|
||||
// Output:
|
||||
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
|
||||
// Album found: {2 Giant Steps John Coltrane 63.99}
|
||||
// ID of added album: 5
|
||||
}
|
||||
|
||||
func createAlbumsTable() error {
|
||||
func albumsSetup() error {
|
||||
_, err := db.Exec(`
|
||||
DROP TABLE IF EXISTS album;
|
||||
CREATE TABLE album (
|
||||
|
||||
@@ -8,12 +8,24 @@ import (
|
||||
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
|
||||
// if it roundtrips back to the same string.
|
||||
// This way times can be persisted to, and recovered from, the database,
|
||||
// but if a string is needed, [database.sql] will recover the same string.
|
||||
// TODO: optimize and fuzz test.
|
||||
func maybeDate(text string) driver.Value {
|
||||
date, err := time.Parse(time.RFC3339Nano, text)
|
||||
if err == nil && date.Format(time.RFC3339Nano) == text {
|
||||
// but if a string is needed, [database/sql] will recover the same string.
|
||||
func stringOrTime(text []byte) driver.Value {
|
||||
// Weed out (some) values that can't possibly be
|
||||
// [time.RFC3339Nano] timestamps.
|
||||
if len(text) < len("2006-01-02T15:04:05Z") {
|
||||
return string(text)
|
||||
}
|
||||
if len(text) > len(time.RFC3339Nano) {
|
||||
return string(text)
|
||||
}
|
||||
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
|
||||
return string(text)
|
||||
}
|
||||
|
||||
// Slow path.
|
||||
date, err := time.Parse(time.RFC3339Nano, string(text))
|
||||
if err == nil && date.Format(time.RFC3339Nano) == string(text) {
|
||||
return date
|
||||
}
|
||||
return text
|
||||
return string(text)
|
||||
}
|
||||
|
||||
100
driver/time_test.go
Normal file
100
driver/time_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This checks that any string can be recovered as the same string.
|
||||
func Fuzz_stringOrTime_1(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add("SQLite")
|
||||
f.Add(time.RFC3339)
|
||||
f.Add(time.RFC3339Nano)
|
||||
f.Add(time.Layout)
|
||||
f.Add(time.DateTime)
|
||||
f.Add(time.DateOnly)
|
||||
f.Add(time.TimeOnly)
|
||||
f.Add("2006-01-02T15:04:05Z")
|
||||
f.Add("2006-01-02T15:04:05.000Z")
|
||||
f.Add("2006-01-02T15:04:05.9999999999Z")
|
||||
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
|
||||
f.Fuzz(func(t *testing.T, str string) {
|
||||
value := stringOrTime([]byte(str))
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
// Make sure times round-trip to the same string:
|
||||
// https://pkg.go.dev/database/sql#Rows.Scan
|
||||
if v.Format(time.RFC3339Nano) != str {
|
||||
t.Fatalf("did not round-trip: %q", str)
|
||||
}
|
||||
case string:
|
||||
if v != str {
|
||||
t.Fatalf("did not round-trip: %q", str)
|
||||
}
|
||||
|
||||
date, err := time.Parse(time.RFC3339Nano, str)
|
||||
if err == nil && date.Format(time.RFC3339Nano) == str {
|
||||
t.Fatalf("would round-trip: %q", str)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("invalid type %T: %q", v, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// This checks that any [time.Time] can be recovered as a [time.Time],
|
||||
// with nanosecond accuracy, and preserving any timezone offset.
|
||||
func Fuzz_stringOrTime_2(f *testing.F) {
|
||||
f.Add(0, 0)
|
||||
f.Add(0, 1)
|
||||
f.Add(0, -1)
|
||||
f.Add(0, 999_999_999)
|
||||
f.Add(0, 1_000_000_000)
|
||||
f.Add(7956915742, 222_222_222) // twosday
|
||||
f.Add(639095955742, 222_222_222) // twosday, year 22222AD
|
||||
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
|
||||
|
||||
checkTime := func(t *testing.T, date time.Time) {
|
||||
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
// Make sure times round-trip to the same time:
|
||||
if !v.Equal(date) {
|
||||
t.Fatalf("did not round-trip: %v", date)
|
||||
}
|
||||
// Make with the same zone offset:
|
||||
_, off1 := v.Zone()
|
||||
_, off2 := date.Zone()
|
||||
if off1 != off2 {
|
||||
t.Fatalf("did not round-trip: %v", date)
|
||||
}
|
||||
case string:
|
||||
t.Fatalf("was not recovered: %v", date)
|
||||
default:
|
||||
t.Fatalf("invalid type %T: %v", v, date)
|
||||
}
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, sec, nsec int) {
|
||||
// Reduce the search space.
|
||||
if 1e12 < sec || sec < -1e12 {
|
||||
// Dates before 29000BC and after 33000AD; I think we're safe.
|
||||
return
|
||||
}
|
||||
if 0 < nsec || nsec > 1e10 {
|
||||
// Out of range nsec: [time.Time.Unix] handles these.
|
||||
return
|
||||
}
|
||||
|
||||
unix := time.Unix(int64(sec), int64(nsec))
|
||||
checkTime(t, unix)
|
||||
checkTime(t, unix.UTC())
|
||||
checkTime(t, unix.In(time.FixedZone("", -8*3600)))
|
||||
checkTime(t, unix.In(time.FixedZone("", +8*3600)))
|
||||
})
|
||||
}
|
||||
75
driver_test.go
Normal file
75
driver_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
func ExampleDriverConn() {
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", "demo.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove("demo.db")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(sqlite3.DriverConn).Raw()
|
||||
savept := conn.Savepoint()
|
||||
defer savept.Release(&err)
|
||||
|
||||
blob, err := conn.OpenBlob("main", "test", "col", id, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
_, err = fmt.Fprint(blob, "Hello BLOB!")
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var msg string
|
||||
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(msg)
|
||||
// Output:
|
||||
// Hello BLOB!
|
||||
}
|
||||
25
embed/README.md
Normal file
25
embed/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# Embeddable WASM build of SQLite
|
||||
|
||||
This folder includes an embeddable WASM build of SQLite 3.43.2 for use with
|
||||
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
|
||||
|
||||
The following optional features are compiled in:
|
||||
- [math functions](https://www.sqlite.org/lang_mathfunc.html)
|
||||
- [FTS3/4](https://www.sqlite.org/fts3.html)/[5](https://www.sqlite.org/fts5.html)
|
||||
- [JSON](https://www.sqlite.org/json1.html)
|
||||
- [R*Tree](https://www.sqlite.org/rtree.html)
|
||||
- [GeoPoly](https://www.sqlite.org/geopoly.html)
|
||||
- [soundex](https://www.sqlite.org/lang_corefunc.html#soundex)
|
||||
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
|
||||
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
|
||||
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
|
||||
- [series](https://github.com/sqlite/sqlite/blob/master/ext/misc/series.c)
|
||||
- [uint](https://github.com/sqlite/sqlite/blob/master/ext/misc/uint.c)
|
||||
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
|
||||
- [time](../sqlite3/time.c)
|
||||
|
||||
See the [configuration options](../sqlite3/sqlite_cfg.h),
|
||||
and [patches](../sqlite3) applied.
|
||||
|
||||
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
|
||||
and [`binaryen`](https://github.com/WebAssembly/binaryen).
|
||||
@@ -1,50 +1,30 @@
|
||||
#!/usr/bin/env bash
|
||||
set -eo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
# download SQLite
|
||||
../sqlite3/download.sh
|
||||
ROOT=../
|
||||
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
# build SQLite
|
||||
zig cc --target=wasm32-wasi -flto -g0 -Os \
|
||||
-o sqlite3.wasm ../sqlite3/*.c \
|
||||
-mmutable-globals \
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
|
||||
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-mexec-model=reactor \
|
||||
-msimd128 -mmutable-globals \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-Wl,--initial-memory=327680 \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
-Wl,--export=malloc \
|
||||
-Wl,--export=free \
|
||||
-Wl,--export=malloc_destructor \
|
||||
-Wl,--export=sqlite3_errcode \
|
||||
-Wl,--export=sqlite3_errstr \
|
||||
-Wl,--export=sqlite3_errmsg \
|
||||
-Wl,--export=sqlite3_error_offset \
|
||||
-Wl,--export=sqlite3_open_v2 \
|
||||
-Wl,--export=sqlite3_close \
|
||||
-Wl,--export=sqlite3_prepare_v3 \
|
||||
-Wl,--export=sqlite3_finalize \
|
||||
-Wl,--export=sqlite3_reset \
|
||||
-Wl,--export=sqlite3_step \
|
||||
-Wl,--export=sqlite3_exec \
|
||||
-Wl,--export=sqlite3_clear_bindings \
|
||||
-Wl,--export=sqlite3_bind_parameter_count \
|
||||
-Wl,--export=sqlite3_bind_parameter_index \
|
||||
-Wl,--export=sqlite3_bind_parameter_name \
|
||||
-Wl,--export=sqlite3_bind_null \
|
||||
-Wl,--export=sqlite3_bind_int64 \
|
||||
-Wl,--export=sqlite3_bind_double \
|
||||
-Wl,--export=sqlite3_bind_text64 \
|
||||
-Wl,--export=sqlite3_bind_blob64 \
|
||||
-Wl,--export=sqlite3_bind_zeroblob64 \
|
||||
-Wl,--export=sqlite3_column_count \
|
||||
-Wl,--export=sqlite3_column_name \
|
||||
-Wl,--export=sqlite3_column_type \
|
||||
-Wl,--export=sqlite3_column_int64 \
|
||||
-Wl,--export=sqlite3_column_double \
|
||||
-Wl,--export=sqlite3_column_text \
|
||||
-Wl,--export=sqlite3_column_blob \
|
||||
-Wl,--export=sqlite3_column_bytes \
|
||||
-Wl,--export=sqlite3_last_insert_rowid \
|
||||
-Wl,--export=sqlite3_changes64 \
|
||||
-Wl,--export=sqlite3_interrupt \
|
||||
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
||||
|
||||
trap 'rm -f sqlite3.tmp' EXIT
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
|
||||
sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
74
embed/exports.txt
Normal file
74
embed/exports.txt
Normal file
@@ -0,0 +1,74 @@
|
||||
free
|
||||
malloc
|
||||
malloc_destructor
|
||||
sqlite3_errcode
|
||||
sqlite3_errstr
|
||||
sqlite3_errmsg
|
||||
sqlite3_error_offset
|
||||
sqlite3_open_v2
|
||||
sqlite3_close
|
||||
sqlite3_close_v2
|
||||
sqlite3_prepare_v3
|
||||
sqlite3_finalize
|
||||
sqlite3_reset
|
||||
sqlite3_step
|
||||
sqlite3_exec
|
||||
sqlite3_clear_bindings
|
||||
sqlite3_bind_parameter_count
|
||||
sqlite3_bind_parameter_index
|
||||
sqlite3_bind_parameter_name
|
||||
sqlite3_bind_null
|
||||
sqlite3_bind_int64
|
||||
sqlite3_bind_double
|
||||
sqlite3_bind_text64
|
||||
sqlite3_bind_blob64
|
||||
sqlite3_bind_zeroblob64
|
||||
sqlite3_column_count
|
||||
sqlite3_column_name
|
||||
sqlite3_column_type
|
||||
sqlite3_column_int64
|
||||
sqlite3_column_double
|
||||
sqlite3_column_text
|
||||
sqlite3_column_blob
|
||||
sqlite3_column_bytes
|
||||
sqlite3_blob_open
|
||||
sqlite3_blob_close
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_blob_bytes
|
||||
sqlite3_blob_read
|
||||
sqlite3_blob_write
|
||||
sqlite3_backup_init
|
||||
sqlite3_backup_step
|
||||
sqlite3_backup_finish
|
||||
sqlite3_backup_remaining
|
||||
sqlite3_backup_pagecount
|
||||
sqlite3_uri_parameter
|
||||
sqlite3_uri_key
|
||||
sqlite3_changes64
|
||||
sqlite3_last_insert_rowid
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_anycollseq_init
|
||||
sqlite3_create_collation_go
|
||||
sqlite3_create_function_go
|
||||
sqlite3_create_aggregate_function_go
|
||||
sqlite3_create_window_function_go
|
||||
sqlite3_aggregate_context
|
||||
sqlite3_user_data
|
||||
sqlite3_set_auxdata_go
|
||||
sqlite3_get_auxdata
|
||||
sqlite3_value_type
|
||||
sqlite3_value_int64
|
||||
sqlite3_value_double
|
||||
sqlite3_value_text
|
||||
sqlite3_value_blob
|
||||
sqlite3_value_bytes
|
||||
sqlite3_result_null
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_double
|
||||
sqlite3_result_text64
|
||||
sqlite3_result_blob64
|
||||
sqlite3_result_zeroblob64
|
||||
sqlite3_result_error
|
||||
sqlite3_result_error_code
|
||||
sqlite3_result_error_nomem
|
||||
sqlite3_result_error_toobig
|
||||
@@ -1,7 +1,9 @@
|
||||
// Package embed embeds SQLite into your application.
|
||||
//
|
||||
// You can obtain this build of SQLite from:
|
||||
// https://github.com/ncruces/go-sqlite3/tree/main/embed
|
||||
// Importing package embed initializes the [sqlite3.Binary] variable
|
||||
// with an appropriate build of SQLite:
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/embed"
|
||||
package embed
|
||||
|
||||
import (
|
||||
|
||||
Binary file not shown.
102
error.go
102
error.go
@@ -1,19 +1,20 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Error wraps an SQLite Error Code.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/errcode.html
|
||||
type Error struct {
|
||||
code uint64
|
||||
str string
|
||||
msg string
|
||||
sql string
|
||||
code uint64
|
||||
}
|
||||
|
||||
// Code returns the primary error code for this error.
|
||||
@@ -50,28 +51,87 @@ func (e *Error) Error() string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
|
||||
//
|
||||
// It makes it possible to do:
|
||||
//
|
||||
// if errors.Is(err, sqlite3.BUSY) {
|
||||
// // ... handle BUSY
|
||||
// }
|
||||
func (e *Error) Is(err error) bool {
|
||||
switch c := err.(type) {
|
||||
case ErrorCode:
|
||||
return c == e.Code()
|
||||
case ExtendedErrorCode:
|
||||
return c == e.ExtendedCode()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
|
||||
func (e *Error) As(err any) bool {
|
||||
switch c := err.(type) {
|
||||
case *ErrorCode:
|
||||
*c = e.Code()
|
||||
return true
|
||||
case *ExtendedErrorCode:
|
||||
*c = e.ExtendedCode()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.Code() == BUSY
|
||||
}
|
||||
|
||||
// Timeout returns true for [BUSY_TIMEOUT] errors.
|
||||
func (e *Error) Timeout() bool {
|
||||
return e.ExtendedCode() == BUSY_TIMEOUT
|
||||
}
|
||||
|
||||
// SQL returns the SQL starting at the token that triggered a syntax error.
|
||||
func (e *Error) SQL() string {
|
||||
return e.sql
|
||||
}
|
||||
|
||||
type errorString string
|
||||
|
||||
func (e errorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
|
||||
oomErr = errorString("sqlite3: out of memory")
|
||||
rangeErr = errorString("sqlite3: index out of range")
|
||||
noNulErr = errorString("sqlite3: missing NUL terminator")
|
||||
noGlobalErr = errorString("sqlite3: could not find global: ")
|
||||
noFuncErr = errorString("sqlite3: could not find function: ")
|
||||
)
|
||||
|
||||
func assertErr() errorString {
|
||||
msg := "sqlite3: assertion failed"
|
||||
if _, file, line, ok := runtime.Caller(1); ok {
|
||||
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
|
||||
}
|
||||
return errorString(msg)
|
||||
// Error implements the error interface.
|
||||
func (e ErrorCode) Error() string {
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ErrorCode) Temporary() bool {
|
||||
return e == BUSY
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ExtendedErrorCode) Error() string {
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
}
|
||||
|
||||
// Is tests whether this error matches a given [ErrorCode].
|
||||
func (e ExtendedErrorCode) Is(err error) bool {
|
||||
c, ok := err.(ErrorCode)
|
||||
return ok && c == ErrorCode(e)
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode].
|
||||
func (e ExtendedErrorCode) As(err any) bool {
|
||||
c, ok := err.(*ErrorCode)
|
||||
if ok {
|
||||
*c = ErrorCode(e)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ExtendedErrorCode) Temporary() bool {
|
||||
return ErrorCode(e) == BUSY
|
||||
}
|
||||
|
||||
// Timeout returns true for [BUSY_TIMEOUT] errors.
|
||||
func (e ExtendedErrorCode) Timeout() bool {
|
||||
return e == BUSY_TIMEOUT
|
||||
}
|
||||
|
||||
166
error_test.go
166
error_test.go
@@ -1,26 +1,168 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
err := Error{code: 0x8080}
|
||||
if rc := err.Code(); rc != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", rc)
|
||||
func Test_assertErr(t *testing.T) {
|
||||
err := util.AssertErr()
|
||||
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") {
|
||||
t.Errorf("got %q", s)
|
||||
}
|
||||
if rc := err.ExtendedCode(); rc != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", rc)
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
err := &Error{code: 0x8080}
|
||||
if !errors.As(err, &err) {
|
||||
t.Fatal("want true")
|
||||
}
|
||||
if ecode := err.Code(); ecode != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err, ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if xcode := err.ExtendedCode(); xcode != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(err, xErrorCode(0x8080)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if s := err.Error(); s != "sqlite3: 32896" {
|
||||
t.Errorf("got %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_assertErr(t *testing.T) {
|
||||
err := assertErr()
|
||||
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") {
|
||||
t.Errorf("got %q", s)
|
||||
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_Temporary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code uint64
|
||||
want bool
|
||||
}{
|
||||
{"ERROR", uint64(ERROR), false},
|
||||
{"BUSY", uint64(BUSY), true},
|
||||
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true},
|
||||
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true},
|
||||
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
{
|
||||
err := &Error{code: tt.code}
|
||||
if got := err.Temporary(); got != tt.want {
|
||||
t.Errorf("Error.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
}
|
||||
{
|
||||
err := ErrorCode(tt.code)
|
||||
if got := err.Temporary(); got != tt.want {
|
||||
t.Errorf("ErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
}
|
||||
{
|
||||
err := ExtendedErrorCode(tt.code)
|
||||
if got := err.Temporary(); got != tt.want {
|
||||
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code uint64
|
||||
want bool
|
||||
}{
|
||||
{"ERROR", uint64(ERROR), false},
|
||||
{"BUSY", uint64(BUSY), false},
|
||||
{"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false},
|
||||
{"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false},
|
||||
{"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
{
|
||||
err := &Error{code: tt.code}
|
||||
if got := err.Timeout(); got != tt.want {
|
||||
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
}
|
||||
{
|
||||
err := ExtendedErrorCode(tt.code)
|
||||
if got := err.Timeout(); got != tt.want {
|
||||
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ErrorCode_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test all error codes.
|
||||
for i := 0; i == int(ErrorCode(i)); i++ {
|
||||
want := "sqlite3: "
|
||||
r := db.call(db.api.errstr, uint64(i))
|
||||
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
|
||||
|
||||
got := ErrorCode(i).Error()
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q, with %d", got, want, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ExtendedErrorCode_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test all extended error codes.
|
||||
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
|
||||
want := "sqlite3: "
|
||||
r := db.call(db.api.errstr, uint64(i))
|
||||
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
|
||||
|
||||
got := ExtendedErrorCode(i).Error()
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q, with %d", got, want, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func Example() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -30,6 +30,7 @@ func Example() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
|
||||
@@ -47,7 +48,6 @@ func Example() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
|
||||
109
ext/stats/stats.go
Normal file
109
ext/stats/stats.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Package stats provides aggregate functions for statistics.
|
||||
//
|
||||
// Functions:
|
||||
// - stddev_pop: population standard deviation
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - var_pop: population variance
|
||||
// - var_samp: sample variance
|
||||
// - covar_pop: population covariance
|
||||
// - covar_samp: sample covariance
|
||||
// - corr: correlation coefficient
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
//
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
|
||||
}
|
||||
|
||||
const (
|
||||
var_pop = iota
|
||||
var_samp
|
||||
stddev_pop
|
||||
stddev_samp
|
||||
corr
|
||||
)
|
||||
|
||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
||||
}
|
||||
|
||||
type variance struct {
|
||||
kind int
|
||||
welford
|
||||
}
|
||||
|
||||
func (fn *variance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.var_pop()
|
||||
case var_samp:
|
||||
r = fn.var_samp()
|
||||
case stddev_pop:
|
||||
r = fn.stddev_pop()
|
||||
case stddev_samp:
|
||||
r = fn.stddev_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func newCovariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
||||
}
|
||||
|
||||
type covariance struct {
|
||||
kind int
|
||||
welford2
|
||||
}
|
||||
|
||||
func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.covar_pop()
|
||||
case var_samp:
|
||||
r = fn.covar_samp()
|
||||
case corr:
|
||||
r = fn.correlation()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
140
ext/stats/stats_test.go
Normal file
140
ext/stats/stats_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister_variance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT
|
||||
sum(x), avg(x),
|
||||
var_samp(x), var_pop(x),
|
||||
stddev_samp(x), stddev_pop(x)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 40 {
|
||||
t.Errorf("got %v, want 40", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(3); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 4.5, 18, 0, 0}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_covariance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 10, 30, 75, 22.5}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
109
ext/stats/welford.go
Normal file
109
ext/stats/welford.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package stats
|
||||
|
||||
import "math"
|
||||
|
||||
// Welford's algorithm with Kahan summation:
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
type welford struct {
|
||||
m1, m2 kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford) average() float64 {
|
||||
return w.m1.hi
|
||||
}
|
||||
|
||||
func (w welford) var_pop() float64 {
|
||||
return w.m2.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford) var_samp() float64 {
|
||||
return w.m2.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford) stddev_pop() float64 {
|
||||
return math.Sqrt(w.var_pop())
|
||||
}
|
||||
|
||||
func (w welford) stddev_samp() float64 {
|
||||
return math.Sqrt(w.var_samp())
|
||||
}
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
|
||||
type welford2 struct {
|
||||
m1x, m2x kahan
|
||||
m1y, m2y kahan
|
||||
cov kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford2) covar_pop() float64 {
|
||||
return w.cov.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford2) covar_samp() float64 {
|
||||
return w.cov.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford2) correlation() float64 {
|
||||
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(x, y float64) {
|
||||
w.n++
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.add(d1x / float64(w.n))
|
||||
w.m1y.add(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.add(d1x * d2x)
|
||||
w.m2y.add(d1y * d2y)
|
||||
w.cov.add(d1x * d2y)
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(x, y float64) {
|
||||
w.n--
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.sub(d1x / float64(w.n))
|
||||
w.m1y.sub(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.sub(d1x * d2x)
|
||||
w.m2y.sub(d1y * d2y)
|
||||
w.cov.sub(d1x * d2y)
|
||||
}
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
y := k.lo + x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
func (k *kahan) sub(x float64) {
|
||||
y := k.lo - x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
75
ext/stats/welford_test.go
Normal file
75
ext/stats/welford_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_welford(t *testing.T) {
|
||||
var s1, s2 welford
|
||||
|
||||
s1.enqueue(4)
|
||||
s1.enqueue(7)
|
||||
s1.enqueue(13)
|
||||
s1.enqueue(16)
|
||||
if got := s1.average(); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := s1.var_samp(); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := s1.var_pop(); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := s1.stddev_samp(); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
|
||||
s1.dequeue(4)
|
||||
s2.enqueue(7)
|
||||
s2.enqueue(13)
|
||||
s2.enqueue(16)
|
||||
if s1.var_pop() != s2.var_pop() {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_covar(t *testing.T) {
|
||||
var c1, c2 welford2
|
||||
|
||||
c1.enqueue(3, 70)
|
||||
c1.enqueue(5, 80)
|
||||
c1.enqueue(2, 60)
|
||||
c1.enqueue(7, 90)
|
||||
c1.enqueue(4, 75)
|
||||
|
||||
if got := c1.covar_samp(); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := c1.covar_pop(); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
|
||||
c1.dequeue(3, 70)
|
||||
c2.enqueue(5, 80)
|
||||
c2.enqueue(2, 60)
|
||||
c2.enqueue(7, 90)
|
||||
c2.enqueue(4, 75)
|
||||
if c1.covar_pop() != c2.covar_pop() {
|
||||
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_correlation(t *testing.T) {
|
||||
var c welford2
|
||||
c.enqueue(1, 3)
|
||||
c.enqueue(2, 2)
|
||||
c.enqueue(3, 1)
|
||||
|
||||
if got := c.correlation(); got != -1 {
|
||||
t.Errorf("got %v, want -1", got)
|
||||
}
|
||||
}
|
||||
181
ext/unicode/unicode.go
Normal file
181
ext/unicode/unicode.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Package unicode provides an alternative to the SQLite ICU extension.
|
||||
//
|
||||
// Like the [ICU extension], it provides Unicode aware:
|
||||
// - upper() and lower() functions,
|
||||
// - LIKE and REGEXP operators,
|
||||
// - collation sequences.
|
||||
//
|
||||
// The implementation is not 100% compatible with the [ICU extension]:
|
||||
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
|
||||
// - the LIKE operator follows [strings.EqualFold] rules;
|
||||
// - the REGEXP operator uses Go [regex/syntax];
|
||||
// - collation sequences use [collate].
|
||||
//
|
||||
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
|
||||
//
|
||||
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
|
||||
db.CreateFunction("like", 2, flags, like)
|
||||
db.CreateFunction("like", 3, flags, like)
|
||||
db.CreateFunction("upper", 1, flags, upper)
|
||||
db.CreateFunction("upper", 2, flags, upper)
|
||||
db.CreateFunction("lower", 1, flags, lower)
|
||||
db.CreateFunction("lower", 2, flags, lower)
|
||||
db.CreateFunction("regexp", 2, flags, regex)
|
||||
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
name := arg[1].Text()
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := RegisterCollation(db, arg[0].Text(), name)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterCollation registers a Unicode collation sequence for a database connection.
|
||||
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
|
||||
tag, err := language.Parse(locale)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.CreateCollation(name, collate.New(tag).Compare)
|
||||
}
|
||||
|
||||
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) == 1 {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
return
|
||||
}
|
||||
cs, ok := ctx.GetAuxData(1).(cases.Caser)
|
||||
if !ok {
|
||||
t, err := language.Parse(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
c := cases.Upper(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
}
|
||||
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
|
||||
}
|
||||
|
||||
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) == 1 {
|
||||
ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob()))
|
||||
return
|
||||
}
|
||||
cs, ok := ctx.GetAuxData(1).(cases.Caser)
|
||||
if !ok {
|
||||
t, err := language.Parse(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
c := cases.Lower(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
}
|
||||
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
|
||||
}
|
||||
|
||||
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
}
|
||||
|
||||
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
escape := rune(-1)
|
||||
if len(arg) == 3 {
|
||||
var size int
|
||||
b := arg[2].RawBlob()
|
||||
escape, size = utf8.DecodeRune(b)
|
||||
if size != len(b) {
|
||||
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type likeData struct {
|
||||
*regexp.Regexp
|
||||
escape rune
|
||||
}
|
||||
|
||||
re, ok := ctx.GetAuxData(0).(likeData)
|
||||
if !ok || re.escape != escape {
|
||||
re = likeData{
|
||||
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
|
||||
escape,
|
||||
}
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
}
|
||||
|
||||
func like2regex(pattern string, escape rune) string {
|
||||
var re strings.Builder
|
||||
start := 0
|
||||
literal := false
|
||||
re.Grow(len(pattern) + 10)
|
||||
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
|
||||
for i, r := range pattern {
|
||||
if start < 0 {
|
||||
start = i
|
||||
}
|
||||
if literal {
|
||||
literal = false
|
||||
continue
|
||||
}
|
||||
var symbol string
|
||||
switch r {
|
||||
case '_':
|
||||
symbol = `.`
|
||||
case '%':
|
||||
symbol = `.*`
|
||||
case escape:
|
||||
literal = true
|
||||
default:
|
||||
continue
|
||||
}
|
||||
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
|
||||
re.WriteString(symbol)
|
||||
start = -1
|
||||
}
|
||||
if start >= 0 {
|
||||
re.WriteString(regexp.QuoteMeta(pattern[start:]))
|
||||
}
|
||||
re.WriteString(`\z`)
|
||||
return re.String()
|
||||
}
|
||||
215
ext/unicode/unicode_test.go
Normal file
215
ext/unicode/unicode_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := func(fn string) string {
|
||||
stmt, _, err := db.Prepare(`SELECT ` + fn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0)
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return ""
|
||||
}
|
||||
|
||||
Register(db)
|
||||
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{`upper('hello')`, "HELLO"},
|
||||
{`lower('HELLO')`, "hello"},
|
||||
{`upper('привет')`, "ПРИВЕТ"},
|
||||
{`lower('ПРИВЕТ')`, "привет"},
|
||||
{`upper('istanbul')`, "ISTANBUL"},
|
||||
{`upper('istanbul', 'tr-TR')`, "İSTANBUL"},
|
||||
{`lower('Dünyanın İlk Borsası', 'tr-TR')`, "dünyanın ilk borsası"},
|
||||
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
|
||||
{`'Hello' REGEXP 'ell'`, "1"},
|
||||
{`'Hello' REGEXP 'el.'`, "1"},
|
||||
{`'Hello' LIKE 'hel_'`, "0"},
|
||||
{`'Hello' LIKE 'hel%'`, "1"},
|
||||
{`'Hello' LIKE 'h_llo'`, "1"},
|
||||
{`'Hello' LIKE 'hello'`, "1"},
|
||||
{`'Привет' LIKE 'ПРИВЕТ'`, "1"},
|
||||
{`'100%' LIKE '100|%' ESCAPE '|'`, "1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.test, func(t *testing.T) {
|
||||
if got := exec(tt.test); got != tt.want {
|
||||
t.Errorf("exec(%q) = %q, want %q", tt.test, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_collation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"}
|
||||
|
||||
for stmt.Step() {
|
||||
got = append(got, stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`SELECT upper('hello', 'enUS')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT lower('hello', 'enUS')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' LIKE 'HELLO' ESCAPE '\\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('enUS', 'error')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('enUS', '')`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_like2regex(t *testing.T) {
|
||||
const prefix = `(?is)\A`
|
||||
const sufix = `\z`
|
||||
tests := []struct {
|
||||
pattern string
|
||||
escape rune
|
||||
want string
|
||||
}{
|
||||
{`a`, -1, `a`},
|
||||
{`a.`, -1, `a\.`},
|
||||
{`a%`, -1, `a.*`},
|
||||
{`a\`, -1, `a\\`},
|
||||
{`a_b`, -1, `a.b`},
|
||||
{`a|b`, '|', `ab`},
|
||||
{`a|_`, '|', `a_`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern, func(t *testing.T) {
|
||||
want := prefix + tt.want + sufix
|
||||
if got := like2regex(tt.pattern, tt.escape); got != want {
|
||||
t.Errorf("like2regex() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
186
func.go
Normal file
186
func.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// AnyCollationNeeded registers a fake collating function
|
||||
// for any unknown collating sequence.
|
||||
// The fake collating function works like BINARY.
|
||||
//
|
||||
// This can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c *Conn) AnyCollationNeeded() {
|
||||
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
|
||||
}
|
||||
|
||||
// CreateCollation defines a new collating sequence.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_collation.html
|
||||
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createCollation,
|
||||
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
|
||||
if err := c.error(r); err != nil {
|
||||
util.DelHandle(c.ctx, funcPtr)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFunction defines a new scalar SQL function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createFunction,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
call := c.api.createAggregate
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
if _, ok := fn().(WindowFunction); ok {
|
||||
call = c.api.createWindow
|
||||
}
|
||||
r := c.call(call,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// AggregateFunction is the interface an aggregate function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/appfunc.html
|
||||
type AggregateFunction interface {
|
||||
// Step is invoked to add a row to the current window.
|
||||
// The function arguments, if any, corresponding to the row being added are passed to Step.
|
||||
Step(ctx Context, arg ...Value)
|
||||
|
||||
// Value is invoked to return the current value of the aggregate.
|
||||
Value(ctx Context)
|
||||
}
|
||||
|
||||
// WindowFunction is the interface an aggregate window function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/windowfunctions.html
|
||||
type WindowFunction interface {
|
||||
AggregateFunction
|
||||
|
||||
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
|
||||
// The function arguments, if any, are those passed to Step for the row being removed.
|
||||
Inverse(ctx Context, arg ...Value)
|
||||
}
|
||||
|
||||
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
||||
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
|
||||
}
|
||||
|
||||
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
var handle uint32
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
if err := util.DelHandle(ctx, handle); err != nil {
|
||||
Context{sqlite, pCtx}.ResultError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
}
|
||||
|
||||
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
|
||||
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
|
||||
return util.GetHandle(sqlite.ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
|
||||
// On close, we're getting rid of the handle.
|
||||
// Don't allocate space to store it.
|
||||
var size uint64
|
||||
if close == nil {
|
||||
size = ptrlen
|
||||
}
|
||||
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
|
||||
|
||||
// Try loading the handle, if we already have one, or want a new one.
|
||||
if ptr != 0 || size != 0 {
|
||||
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
|
||||
fn := util.GetHandle(sqlite.ctx, handle)
|
||||
if close != nil {
|
||||
*close = handle
|
||||
}
|
||||
if fn != nil {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new aggregate and store the handle.
|
||||
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
|
||||
if ptr != 0 {
|
||||
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
|
||||
args := make([]Value, nArg)
|
||||
for i := range args {
|
||||
args[i] = Value{
|
||||
sqlite: sqlite,
|
||||
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
154
func_test.go
Normal file
154
func_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateCollation() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateCollation("french", collate.New(language.French).Compare)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// cote
|
||||
// coté
|
||||
// côte
|
||||
// côté
|
||||
// cotée
|
||||
// coter
|
||||
}
|
||||
|
||||
func ExampleConn_CreateFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// COTE
|
||||
// COTÉ
|
||||
// CÔTE
|
||||
// CÔTÉ
|
||||
// COTÉE
|
||||
// COTER
|
||||
}
|
||||
|
||||
func ExampleContext_SetAuxData() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
ctx.SetAuxData(0, r)
|
||||
re = r
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// cote
|
||||
// côte
|
||||
// cotée
|
||||
}
|
||||
87
func_win_test.go
Normal file
87
func_win_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"unicode"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateWindowFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnInt(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// 1
|
||||
// 2
|
||||
// 2
|
||||
// 1
|
||||
// 0
|
||||
// 0
|
||||
}
|
||||
|
||||
type countASCII struct{ result int }
|
||||
|
||||
func newASCIICounter() sqlite3.AggregateFunction {
|
||||
return &countASCII{}
|
||||
}
|
||||
|
||||
func (f *countASCII) Value(ctx sqlite3.Context) {
|
||||
ctx.ResultInt(f.result)
|
||||
}
|
||||
|
||||
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result++
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result--
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
|
||||
if arg.Type() != sqlite3.TEXT {
|
||||
return false
|
||||
}
|
||||
for _, c := range arg.RawBlob() {
|
||||
if c > unicode.MaxASCII {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
12
go.mod
12
go.mod
@@ -1,10 +1,14 @@
|
||||
module github.com/ncruces/go-sqlite3
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v0.1.5
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.8
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/sys v0.5.0
|
||||
github.com/psanford/httpreadat v0.1.0
|
||||
github.com/tetratelabs/wazero v1.5.0
|
||||
golang.org/x/sync v0.4.0
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.org/x/text v0.13.0
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
|
||||
16
go.sum
16
go.sum
@@ -1,8 +1,12 @@
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.8 h1:Ir82PWj79WCppH+9ny73eGY2qv+oCnE3VwMY92cBSyI=
|
||||
github.com/tetratelabs/wazero v1.0.0-pre.8/go.mod h1:u8wrFmpdrykiFK0DFPiFm5a4+0RzsdmXYVtijBKqUVo=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
|
||||
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
|
||||
5
go.work.sum
Normal file
5
go.work.sum
Normal file
@@ -0,0 +1,5 @@
|
||||
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
|
||||
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
22
gormlite/LICENSE
Normal file
22
gormlite/LICENSE
Normal file
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Nuno Cruces
|
||||
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
26
gormlite/README.md
Normal file
26
gormlite/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# GORM SQLite Driver
|
||||
|
||||
[](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/gormlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
|
||||
### Foreign-key constraint activation
|
||||
|
||||
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
|
||||
```go
|
||||
db, err := gorm.Open(gormlite.Open(
|
||||
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=foreign_keys(1)"),
|
||||
&gorm.Config{})
|
||||
```
|
||||
231
gormlite/ddlmod.go
Normal file
231
gormlite/ddlmod.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteSeparator = "`|\"|'|\t"
|
||||
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
|
||||
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
|
||||
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
|
||||
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
|
||||
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
|
||||
)
|
||||
|
||||
func getAllColumns(s string) []string {
|
||||
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
|
||||
columns := make([]string, 0, len(allMatches))
|
||||
for _, matches := range allMatches {
|
||||
if len(matches) > 1 {
|
||||
columns = append(columns, matches[1])
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
type ddl struct {
|
||||
head string
|
||||
fields []string
|
||||
columns []migrator.ColumnType
|
||||
}
|
||||
|
||||
func parseDDL(strs ...string) (*ddl, error) {
|
||||
var result ddl
|
||||
for _, str := range strs {
|
||||
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
|
||||
var (
|
||||
ddlBody = sections[2]
|
||||
ddlBodyRunes = []rune(ddlBody)
|
||||
bracketLevel int
|
||||
quote rune
|
||||
buf string
|
||||
)
|
||||
ddlBodyRunesLen := len(ddlBodyRunes)
|
||||
|
||||
result.head = sections[1]
|
||||
|
||||
for idx := 0; idx < ddlBodyRunesLen; idx++ {
|
||||
var (
|
||||
next rune = 0
|
||||
c = ddlBodyRunes[idx]
|
||||
)
|
||||
if idx+1 < ddlBodyRunesLen {
|
||||
next = ddlBodyRunes[idx+1]
|
||||
}
|
||||
|
||||
if sc := string(c); separatorRegexp.MatchString(sc) {
|
||||
if c == next {
|
||||
buf += sc // Skip escaped quote
|
||||
idx++
|
||||
} else if quote > 0 {
|
||||
quote = 0
|
||||
} else {
|
||||
quote = c
|
||||
}
|
||||
} else if quote == 0 {
|
||||
if c == '(' {
|
||||
bracketLevel++
|
||||
} else if c == ')' {
|
||||
bracketLevel--
|
||||
} else if bracketLevel == 0 {
|
||||
if c == ',' {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
buf = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bracketLevel < 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
buf += string(c)
|
||||
}
|
||||
|
||||
if bracketLevel != 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
if buf != "" {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
}
|
||||
|
||||
for _, f := range result.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
|
||||
for _, name := range getAllColumns(f) {
|
||||
for idx, column := range result.columns {
|
||||
if column.NameValue.String == name {
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
result.columns[idx] = column
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
|
||||
columnType := migrator.ColumnType{
|
||||
NameValue: sql.NullString{String: matches[1], Valid: true},
|
||||
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
UniqueValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
}
|
||||
|
||||
matchUpper := strings.ToUpper(matches[3])
|
||||
if strings.Contains(matchUpper, " NOT NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
|
||||
} else if strings.Contains(matchUpper, " NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " UNIQUE") {
|
||||
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " PRIMARY") {
|
||||
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
|
||||
if strings.ToLower(defaultMatches[1]) != "null" {
|
||||
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
|
||||
}
|
||||
}
|
||||
|
||||
// data type length
|
||||
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
|
||||
if len(matches) == 1 && len(matches[0]) == 2 {
|
||||
size, _ := strconv.Atoi(matches[0][1])
|
||||
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
|
||||
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
|
||||
}
|
||||
|
||||
result.columns = append(result.columns, columnType)
|
||||
}
|
||||
}
|
||||
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
|
||||
for _, column := range getAllColumns(matches[1]) {
|
||||
for idx, c := range result.columns {
|
||||
if c.NameValue.String == column {
|
||||
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
|
||||
result.columns[idx] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("invalid DDL")
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
}
|
||||
|
||||
func (d *ddl) removeConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
for _, f := range d.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
|
||||
strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
|
||||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
|
||||
continue
|
||||
}
|
||||
|
||||
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
|
||||
match := reg.FindStringSubmatch(f)
|
||||
|
||||
if match != nil {
|
||||
res = append(res, "`"+match[1]+"`")
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
352
gormlite/ddlmod_test.go
Normal file
352
gormlite/ddlmod_test.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestParseDDL(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{"with_fk", []string{
|
||||
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
|
||||
}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
}},
|
||||
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"no brackets", []string{"create table test"}, 0, nil},
|
||||
{"with_special_characters", []string{
|
||||
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
|
||||
}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"table_name_with_dash",
|
||||
[]string{
|
||||
"CREATE TABLE `test-a` (`id` int NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"non-unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
|
||||
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
tests.AssertEqual(t, p.sql[0], ddl.compile())
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
testColumns := []migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
|
||||
},
|
||||
{
|
||||
NameValue: sql.NullString{String: "dark_mode", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
DefaultValueValue: sql.NullString{String: "true", Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{
|
||||
"with_newline",
|
||||
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_newline_2",
|
||||
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_missing_space",
|
||||
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_many_spaces",
|
||||
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
}
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_error(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql string
|
||||
}{
|
||||
{"invalid_cmd", "CREATE TABLE"},
|
||||
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
|
||||
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
_, err := parseDDL(p.sql)
|
||||
if err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
sql string
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "add_new",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
},
|
||||
{
|
||||
name: "update",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
},
|
||||
{
|
||||
name: "add_check",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
{
|
||||
name: "update_check",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
testDDL.addConstraint(p.cName, p.sql)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
success bool
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "fk",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "check",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "none",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "nothing",
|
||||
success: false,
|
||||
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
success := testDDL.removeConstraint(p.cName)
|
||||
|
||||
tests.AssertEqual(t, p.success, success)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumns(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
ddl string
|
||||
columns []string
|
||||
}{
|
||||
{
|
||||
name: "with_fk",
|
||||
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
columns: []string{"`id`", "`text`", "`user_id`"},
|
||||
},
|
||||
{
|
||||
name: "with_check",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
|
||||
},
|
||||
{
|
||||
name: "with_escaped_quote",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_generated_column",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_new_line",
|
||||
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
|
||||
"id" integer PRIMARY KEY AUTOINCREMENT,
|
||||
"created_at" datetime NOT NULL,
|
||||
"updated_at" datetime NOT NULL,
|
||||
"created_by" integer NOT NULL DEFAULT 0,
|
||||
"updated_by" integer NOT NULL DEFAULT 0,
|
||||
"role_id" integer NOT NULL,
|
||||
"menu_id" bigint NOT NULL
|
||||
)`,
|
||||
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL, err := parseDDL(p.ddl)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
cols := testDDL.getColumns()
|
||||
|
||||
tests.AssertEqual(t, p.columns, cols)
|
||||
})
|
||||
}
|
||||
}
|
||||
11
gormlite/download.sh
Executable file
11
gormlite/download.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"
|
||||
21
gormlite/error_translator.go
Normal file
21
gormlite/error_translator.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
switch {
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
|
||||
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
|
||||
return gorm.ErrDuplicatedKey
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
|
||||
return gorm.ErrForeignKeyViolated
|
||||
}
|
||||
return err
|
||||
}
|
||||
16
gormlite/go.mod
Normal file
16
gormlite/go.mod
Normal file
@@ -0,0 +1,16 @@
|
||||
module github.com/ncruces/go-sqlite3/gormlite
|
||||
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/go-sqlite3 v0.9.0
|
||||
gorm.io/gorm v1.25.4
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/ncruces/julianday v0.1.5 // indirect
|
||||
github.com/tetratelabs/wazero v1.5.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
)
|
||||
16
gormlite/go.sum
Normal file
16
gormlite/go.sum
Normal file
@@ -0,0 +1,16 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/ncruces/go-sqlite3 v0.9.0 h1:tl5eEmGEyzZH2ur8sDgPJTdzV4CRnKpsFngoP1QRjD8=
|
||||
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
|
||||
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
431
gormlite/migrator.go
Normal file
431
gormlite/migrator.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
var enabled int
|
||||
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
|
||||
if enabled == 1 {
|
||||
m.DB.Exec("PRAGMA foreign_keys = OFF")
|
||||
defer m.DB.Exec("PRAGMA foreign_keys = ON")
|
||||
}
|
||||
|
||||
return fc()
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
|
||||
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
|
||||
|
||||
if createSQL == rawDDL {
|
||||
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
|
||||
}
|
||||
|
||||
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
}
|
||||
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
sqls []string
|
||||
sqlDDL *ddl
|
||||
)
|
||||
|
||||
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlDDL, err = parseDDL(sqls...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = rows.Close()
|
||||
}()
|
||||
|
||||
var rawColumnTypes []*sql.ColumnType
|
||||
rawColumnTypes, err = rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
columnType := migrator.ColumnType{SQLColumnType: c}
|
||||
for _, column := range sqlDDL.columns {
|
||||
if column.NameValue.String == c.Name() {
|
||||
column.SQLColumnType = c
|
||||
columnType = column
|
||||
break
|
||||
}
|
||||
}
|
||||
columnTypes = append(columnTypes, columnType)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, "")
|
||||
|
||||
return createSQL, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
constraintValues []interface{}
|
||||
)
|
||||
|
||||
if constraint != nil {
|
||||
constraintName = constraint.Name
|
||||
constraintSql, constraintValues = buildConstraint(constraint)
|
||||
} else if chk != nil {
|
||||
constraintName = chk.Name
|
||||
constraintSql = "CONSTRAINT ? CHECK (?)"
|
||||
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
} else {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.addConstraint(constraintName, constraintSql)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.removeConstraint(name)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, nil, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
createIndexSQL += " ON ??"
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
if sql != "" {
|
||||
if err := m.DropIndex(value, oldName); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
|
||||
}
|
||||
return fmt.Errorf("failed to find index with name %v", oldName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
var createSQL string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
|
||||
|
||||
if m.DB.Error != nil {
|
||||
return "", m.DB.Error
|
||||
}
|
||||
return createSQL, nil
|
||||
}
|
||||
|
||||
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
table := stmt.Table
|
||||
if tablePtr != nil {
|
||||
table = *tablePtr
|
||||
}
|
||||
|
||||
rawDDL, err := m.getRawDDL(table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newTableName := table + "__temp"
|
||||
|
||||
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createSQL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
|
||||
createDDL, err := parseDDL(createSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
columns := createDDL.getColumns()
|
||||
|
||||
return m.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := []string{
|
||||
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
|
||||
fmt.Sprintf("DROP TABLE `%v`", table),
|
||||
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
|
||||
}
|
||||
for _, query := range queries {
|
||||
if err := tx.Exec(query).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
249
gormlite/sqlite.go
Normal file
249
gormlite/sqlite.go
Normal file
@@ -0,0 +1,249 @@
|
||||
// Package gormlite provides a GORM driver for SQLite.
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
conn, err := driver.Open(dialector.DSN, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.ConnPool = conn
|
||||
}
|
||||
|
||||
var version string
|
||||
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
|
||||
return err
|
||||
}
|
||||
// https://www.sqlite.org/releaselog/3_35_0.html
|
||||
if compareVersion(version, "3.35.0") >= 0 {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
} else {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"INSERT": func(c clause.Clause, builder clause.Builder) {
|
||||
if insert, ok := c.Expression.(clause.Insert); ok {
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
stmt.WriteString("INSERT ")
|
||||
if insert.Modifier != "" {
|
||||
stmt.WriteString(insert.Modifier)
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
stmt.WriteString("INTO ")
|
||||
if insert.Table.Name == "" {
|
||||
stmt.WriteQuoted(stmt.Table)
|
||||
} else {
|
||||
stmt.WriteQuoted(insert.Table)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
},
|
||||
"LIMIT": func(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
var lmt = -1
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
lmt = *limit.Limit
|
||||
}
|
||||
if lmt >= 0 || limit.Offset > 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(lmt))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
builder.WriteString(" OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
}
|
||||
}
|
||||
},
|
||||
"FOR": func(c clause.Clause, builder clause.Builder) {
|
||||
if _, ok := c.Expression.(clause.Locking); ok {
|
||||
// SQLite3 does not support row-level locking.
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
if field.AutoIncrement {
|
||||
return clause.Expr{SQL: "NULL"}
|
||||
}
|
||||
|
||||
// doesn't work, will raise error
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
var (
|
||||
underQuoted, selfQuoted bool
|
||||
continuousBacktick int8
|
||||
shiftDelimiter int8
|
||||
)
|
||||
|
||||
for _, v := range []byte(str) {
|
||||
switch v {
|
||||
case '`':
|
||||
continuousBacktick++
|
||||
if continuousBacktick == 2 {
|
||||
writer.WriteString("``")
|
||||
continuousBacktick = 0
|
||||
}
|
||||
case '.':
|
||||
if continuousBacktick > 0 || !selfQuoted {
|
||||
shiftDelimiter = 0
|
||||
underQuoted = false
|
||||
continuousBacktick = 0
|
||||
writer.WriteString("`")
|
||||
}
|
||||
writer.WriteByte(v)
|
||||
continue
|
||||
default:
|
||||
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
||||
writer.WriteString("`")
|
||||
underQuoted = true
|
||||
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
||||
continuousBacktick -= 1
|
||||
}
|
||||
}
|
||||
|
||||
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
|
||||
writer.WriteByte(v)
|
||||
}
|
||||
shiftDelimiter++
|
||||
}
|
||||
|
||||
if continuousBacktick > 0 && !selfQuoted {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
writer.WriteString("`")
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement && !field.PrimaryKey {
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
return "integer"
|
||||
}
|
||||
case schema.Float:
|
||||
return "real"
|
||||
case schema.String:
|
||||
return "text"
|
||||
case schema.Time:
|
||||
// Distinguish between schema.Time and tag time
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
return val
|
||||
} else {
|
||||
return "datetime"
|
||||
}
|
||||
case schema.Bytes:
|
||||
return "blob"
|
||||
}
|
||||
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareVersion(version1, version2 string) int {
|
||||
n, m := len(version1), len(version2)
|
||||
i, j := 0, 0
|
||||
for i < n || j < m {
|
||||
x := 0
|
||||
for ; i < n && version1[i] != '.'; i++ {
|
||||
x = x*10 + int(version1[i]-'0')
|
||||
}
|
||||
i++
|
||||
y := 0
|
||||
for ; j < m && version2[j] != '.'; j++ {
|
||||
y = y*10 + int(version2[j]-'0')
|
||||
}
|
||||
j++
|
||||
if x > y {
|
||||
return 1
|
||||
}
|
||||
if x < y {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
105
gormlite/sqlite_test.go
Normal file
105
gormlite/sqlite_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDialector(t *testing.T) {
|
||||
// This is the DSN of the in-memory SQLite database for these tests.
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
// Custom connection with a custom function called "my_custom_function".
|
||||
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
|
||||
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText("my-result")
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector *Dialector
|
||||
openSuccess bool
|
||||
query string
|
||||
querySuccess bool
|
||||
}{
|
||||
{
|
||||
description: "Default driver",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom function",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: false,
|
||||
},
|
||||
{
|
||||
description: "Custom connection",
|
||||
dialector: &Dialector{
|
||||
Conn: conn,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom connection, custom function",
|
||||
dialector: &Dialector{
|
||||
Conn: conn,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: true,
|
||||
},
|
||||
}
|
||||
for rowIndex, row := range rows {
|
||||
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
|
||||
db, err := gorm.Open(row.dialector, &gorm.Config{})
|
||||
if !row.openSuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected Open to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected Open to succeed; got error: %v", err)
|
||||
}
|
||||
if db == nil {
|
||||
t.Errorf("Expected db to be non-nil.")
|
||||
}
|
||||
if row.query != "" {
|
||||
err = db.Exec(row.query).Error
|
||||
if !row.querySuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected query to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected query to succeed; got error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
24
gormlite/test.sh
Executable file
24
gormlite/test.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
rm -rf gorm/ tests/
|
||||
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
|
||||
mv gorm/tests tests
|
||||
rm -rf gorm/
|
||||
|
||||
patch -p1 -N < tests.patch
|
||||
|
||||
cd tests
|
||||
go mod edit \
|
||||
-require github.com/ncruces/go-sqlite3/gormlite@v0.0.0 \
|
||||
-replace github.com/ncruces/go-sqlite3/gormlite=../ \
|
||||
-replace github.com/ncruces/go-sqlite3=../../ \
|
||||
-droprequire gorm.io/driver/sqlite \
|
||||
-dropreplace gorm.io/gorm
|
||||
go mod tidy && go work use . && go test
|
||||
|
||||
cd ..
|
||||
rm -rf tests/
|
||||
go work use -r .
|
||||
31
gormlite/tests.patch
Normal file
31
gormlite/tests.patch
Normal file
@@ -0,0 +1,31 @@
|
||||
diff --git a/tests/.gitignore b/tests/.gitignore
|
||||
--- a/tests/.gitignore
|
||||
+++ b/tests/.gitignore
|
||||
@@ -1 +1 @@
|
||||
-go.sum
|
||||
+*
|
||||
diff --git a/tests/tests_test.go b/tests/tests_test.go
|
||||
--- a/tests/tests_test.go
|
||||
+++ b/tests/tests_test.go
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
+ _ "github.com/ncruces/go-sqlite3/embed"
|
||||
+ sqlite "github.com/ncruces/go-sqlite3/gormlite"
|
||||
+
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
- "gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -89,7 +91,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
||||
default:
|
||||
log.Println("testing sqlite3...")
|
||||
- db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
|
||||
+ db, err = gorm.Open(sqlite.Open("file:"+filepath.Join(os.TempDir(), "gorm.db")+"?_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"), cfg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
22
internal/util/bool.go
Normal file
22
internal/util/bool.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package util
|
||||
|
||||
import "strings"
|
||||
|
||||
func ParseBool(s string) (b, ok bool) {
|
||||
if len(s) == 0 {
|
||||
return false, false
|
||||
}
|
||||
if s[0] == '0' {
|
||||
return false, true
|
||||
}
|
||||
if '1' <= s[0] && s[0] <= '9' {
|
||||
return true, true
|
||||
}
|
||||
switch strings.ToLower(s) {
|
||||
case "true", "yes", "on":
|
||||
return true, true
|
||||
case "false", "no", "off":
|
||||
return false, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
28
internal/util/bool_test.go
Normal file
28
internal/util/bool_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
str string
|
||||
val bool
|
||||
ok bool
|
||||
}{
|
||||
{"", false, false},
|
||||
{"0", false, true},
|
||||
{"1", true, true},
|
||||
{"9", true, true},
|
||||
{"T", false, false},
|
||||
{"true", true, true},
|
||||
{"FALSE", false, true},
|
||||
{"false?", false, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.str, func(t *testing.T) {
|
||||
gotVal, gotOK := ParseBool(tt.str)
|
||||
if gotVal != tt.val || gotOK != tt.ok {
|
||||
t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
116
internal/util/const.go
Normal file
116
internal/util/const.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package util
|
||||
|
||||
// https://sqlite.com/matrix/rescode.html
|
||||
const (
|
||||
OK = 0 /* Successful result */
|
||||
|
||||
ERROR = 1 /* Generic error */
|
||||
INTERNAL = 2 /* Internal logic error in SQLite */
|
||||
PERM = 3 /* Access permission denied */
|
||||
ABORT = 4 /* Callback routine requested an abort */
|
||||
BUSY = 5 /* The database file is locked */
|
||||
LOCKED = 6 /* A table in the database is locked */
|
||||
NOMEM = 7 /* A malloc() failed */
|
||||
READONLY = 8 /* Attempt to write a readonly database */
|
||||
INTERRUPT = 9 /* Operation terminated by sqlite3_interrupt() */
|
||||
IOERR = 10 /* Some kind of disk I/O error occurred */
|
||||
CORRUPT = 11 /* The database disk image is malformed */
|
||||
NOTFOUND = 12 /* Unknown opcode in sqlite3_file_control() */
|
||||
FULL = 13 /* Insertion failed because database is full */
|
||||
CANTOPEN = 14 /* Unable to open the database file */
|
||||
PROTOCOL = 15 /* Database lock protocol error */
|
||||
EMPTY = 16 /* Internal use only */
|
||||
SCHEMA = 17 /* The database schema changed */
|
||||
TOOBIG = 18 /* String or BLOB exceeds size limit */
|
||||
CONSTRAINT = 19 /* Abort due to constraint violation */
|
||||
MISMATCH = 20 /* Data type mismatch */
|
||||
MISUSE = 21 /* Library used incorrectly */
|
||||
NOLFS = 22 /* Uses OS features not supported on host */
|
||||
AUTH = 23 /* Authorization denied */
|
||||
FORMAT = 24 /* Not used */
|
||||
RANGE = 25 /* 2nd parameter to sqlite3_bind out of range */
|
||||
NOTADB = 26 /* File opened that is not a database file */
|
||||
NOTICE = 27 /* Notifications from sqlite3_log() */
|
||||
WARNING = 28 /* Warnings from sqlite3_log() */
|
||||
|
||||
ROW = 100 /* sqlite3_step() has another row ready */
|
||||
DONE = 101 /* sqlite3_step() has finished executing */
|
||||
|
||||
ERROR_MISSING_COLLSEQ = ERROR | (1 << 8)
|
||||
ERROR_RETRY = ERROR | (2 << 8)
|
||||
ERROR_SNAPSHOT = ERROR | (3 << 8)
|
||||
IOERR_READ = IOERR | (1 << 8)
|
||||
IOERR_SHORT_READ = IOERR | (2 << 8)
|
||||
IOERR_WRITE = IOERR | (3 << 8)
|
||||
IOERR_FSYNC = IOERR | (4 << 8)
|
||||
IOERR_DIR_FSYNC = IOERR | (5 << 8)
|
||||
IOERR_TRUNCATE = IOERR | (6 << 8)
|
||||
IOERR_FSTAT = IOERR | (7 << 8)
|
||||
IOERR_UNLOCK = IOERR | (8 << 8)
|
||||
IOERR_RDLOCK = IOERR | (9 << 8)
|
||||
IOERR_DELETE = IOERR | (10 << 8)
|
||||
IOERR_BLOCKED = IOERR | (11 << 8)
|
||||
IOERR_NOMEM = IOERR | (12 << 8)
|
||||
IOERR_ACCESS = IOERR | (13 << 8)
|
||||
IOERR_CHECKRESERVEDLOCK = IOERR | (14 << 8)
|
||||
IOERR_LOCK = IOERR | (15 << 8)
|
||||
IOERR_CLOSE = IOERR | (16 << 8)
|
||||
IOERR_DIR_CLOSE = IOERR | (17 << 8)
|
||||
IOERR_SHMOPEN = IOERR | (18 << 8)
|
||||
IOERR_SHMSIZE = IOERR | (19 << 8)
|
||||
IOERR_SHMLOCK = IOERR | (20 << 8)
|
||||
IOERR_SHMMAP = IOERR | (21 << 8)
|
||||
IOERR_SEEK = IOERR | (22 << 8)
|
||||
IOERR_DELETE_NOENT = IOERR | (23 << 8)
|
||||
IOERR_MMAP = IOERR | (24 << 8)
|
||||
IOERR_GETTEMPPATH = IOERR | (25 << 8)
|
||||
IOERR_CONVPATH = IOERR | (26 << 8)
|
||||
IOERR_VNODE = IOERR | (27 << 8)
|
||||
IOERR_AUTH = IOERR | (28 << 8)
|
||||
IOERR_BEGIN_ATOMIC = IOERR | (29 << 8)
|
||||
IOERR_COMMIT_ATOMIC = IOERR | (30 << 8)
|
||||
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
|
||||
IOERR_DATA = IOERR | (32 << 8)
|
||||
IOERR_CORRUPTFS = IOERR | (33 << 8)
|
||||
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
|
||||
LOCKED_VTAB = LOCKED | (2 << 8)
|
||||
BUSY_RECOVERY = BUSY | (1 << 8)
|
||||
BUSY_SNAPSHOT = BUSY | (2 << 8)
|
||||
BUSY_TIMEOUT = BUSY | (3 << 8)
|
||||
CANTOPEN_NOTEMPDIR = CANTOPEN | (1 << 8)
|
||||
CANTOPEN_ISDIR = CANTOPEN | (2 << 8)
|
||||
CANTOPEN_FULLPATH = CANTOPEN | (3 << 8)
|
||||
CANTOPEN_CONVPATH = CANTOPEN | (4 << 8)
|
||||
CANTOPEN_DIRTYWAL = CANTOPEN | (5 << 8) /* Not Used */
|
||||
CANTOPEN_SYMLINK = CANTOPEN | (6 << 8)
|
||||
CORRUPT_VTAB = CORRUPT | (1 << 8)
|
||||
CORRUPT_SEQUENCE = CORRUPT | (2 << 8)
|
||||
CORRUPT_INDEX = CORRUPT | (3 << 8)
|
||||
READONLY_RECOVERY = READONLY | (1 << 8)
|
||||
READONLY_CANTLOCK = READONLY | (2 << 8)
|
||||
READONLY_ROLLBACK = READONLY | (3 << 8)
|
||||
READONLY_DBMOVED = READONLY | (4 << 8)
|
||||
READONLY_CANTINIT = READONLY | (5 << 8)
|
||||
READONLY_DIRECTORY = READONLY | (6 << 8)
|
||||
ABORT_ROLLBACK = ABORT | (2 << 8)
|
||||
CONSTRAINT_CHECK = CONSTRAINT | (1 << 8)
|
||||
CONSTRAINT_COMMITHOOK = CONSTRAINT | (2 << 8)
|
||||
CONSTRAINT_FOREIGNKEY = CONSTRAINT | (3 << 8)
|
||||
CONSTRAINT_FUNCTION = CONSTRAINT | (4 << 8)
|
||||
CONSTRAINT_NOTNULL = CONSTRAINT | (5 << 8)
|
||||
CONSTRAINT_PRIMARYKEY = CONSTRAINT | (6 << 8)
|
||||
CONSTRAINT_TRIGGER = CONSTRAINT | (7 << 8)
|
||||
CONSTRAINT_UNIQUE = CONSTRAINT | (8 << 8)
|
||||
CONSTRAINT_VTAB = CONSTRAINT | (9 << 8)
|
||||
CONSTRAINT_ROWID = CONSTRAINT | (10 << 8)
|
||||
CONSTRAINT_PINNED = CONSTRAINT | (11 << 8)
|
||||
CONSTRAINT_DATATYPE = CONSTRAINT | (12 << 8)
|
||||
NOTICE_RECOVER_WAL = NOTICE | (1 << 8)
|
||||
NOTICE_RECOVER_ROLLBACK = NOTICE | (2 << 8)
|
||||
NOTICE_RBU = NOTICE | (3 << 8)
|
||||
WARNING_AUTOINDEX = WARNING | (1 << 8)
|
||||
AUTH_USER = AUTH | (1 << 8)
|
||||
|
||||
OK_LOAD_PERMANENTLY = OK | (1 << 8)
|
||||
OK_SYMLINK = OK | (2 << 8) /* internal use only */
|
||||
)
|
||||
115
internal/util/error.go
Normal file
115
internal/util/error.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ErrorString string
|
||||
|
||||
func (e ErrorString) Error() string { return string(e) }
|
||||
|
||||
const (
|
||||
NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference")
|
||||
OOMErr = ErrorString("sqlite3: out of memory")
|
||||
RangeErr = ErrorString("sqlite3: index out of range")
|
||||
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
|
||||
NoGlobalErr = ErrorString("sqlite3: could not find global: ")
|
||||
NoFuncErr = ErrorString("sqlite3: could not find function: ")
|
||||
BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
|
||||
TimeErr = ErrorString("sqlite3: invalid time value")
|
||||
WhenceErr = ErrorString("sqlite3: invalid whence")
|
||||
OffsetErr = ErrorString("sqlite3: invalid offset")
|
||||
TailErr = ErrorString("sqlite3: multiple statements")
|
||||
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
|
||||
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
|
||||
)
|
||||
|
||||
func AssertErr() ErrorString {
|
||||
msg := "sqlite3: assertion failed"
|
||||
if _, file, line, ok := runtime.Caller(1); ok {
|
||||
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
|
||||
}
|
||||
return ErrorString(msg)
|
||||
}
|
||||
|
||||
func Finalizer[T any](skip int) func(*T) {
|
||||
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
|
||||
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
|
||||
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
|
||||
}
|
||||
return func(*T) { panic(ErrorString(msg)) }
|
||||
}
|
||||
|
||||
func ErrorCodeString(rc uint32) string {
|
||||
switch rc {
|
||||
case ABORT_ROLLBACK:
|
||||
return "sqlite3: abort due to ROLLBACK"
|
||||
case ROW:
|
||||
return "sqlite3: another row available"
|
||||
case DONE:
|
||||
return "sqlite3: no more rows available"
|
||||
}
|
||||
switch rc & 0xff {
|
||||
case OK:
|
||||
return "sqlite3: not an error"
|
||||
case ERROR:
|
||||
return "sqlite3: SQL logic error"
|
||||
case INTERNAL:
|
||||
break
|
||||
case PERM:
|
||||
return "sqlite3: access permission denied"
|
||||
case ABORT:
|
||||
return "sqlite3: query aborted"
|
||||
case BUSY:
|
||||
return "sqlite3: database is locked"
|
||||
case LOCKED:
|
||||
return "sqlite3: database table is locked"
|
||||
case NOMEM:
|
||||
return "sqlite3: out of memory"
|
||||
case READONLY:
|
||||
return "sqlite3: attempt to write a readonly database"
|
||||
case INTERRUPT:
|
||||
return "sqlite3: interrupted"
|
||||
case IOERR:
|
||||
return "sqlite3: disk I/O error"
|
||||
case CORRUPT:
|
||||
return "sqlite3: database disk image is malformed"
|
||||
case NOTFOUND:
|
||||
return "sqlite3: unknown operation"
|
||||
case FULL:
|
||||
return "sqlite3: database or disk is full"
|
||||
case CANTOPEN:
|
||||
return "sqlite3: unable to open database file"
|
||||
case PROTOCOL:
|
||||
return "sqlite3: locking protocol"
|
||||
case FORMAT:
|
||||
break
|
||||
case SCHEMA:
|
||||
return "sqlite3: database schema has changed"
|
||||
case TOOBIG:
|
||||
return "sqlite3: string or blob too big"
|
||||
case CONSTRAINT:
|
||||
return "sqlite3: constraint failed"
|
||||
case MISMATCH:
|
||||
return "sqlite3: datatype mismatch"
|
||||
case MISUSE:
|
||||
return "sqlite3: bad parameter or other API misuse"
|
||||
case NOLFS:
|
||||
break
|
||||
case AUTH:
|
||||
return "sqlite3: authorization denied"
|
||||
case EMPTY:
|
||||
break
|
||||
case RANGE:
|
||||
return "sqlite3: column index out of range"
|
||||
case NOTADB:
|
||||
return "sqlite3: file is not a database"
|
||||
case NOTICE:
|
||||
return "sqlite3: notification message"
|
||||
case WARNING:
|
||||
return "sqlite3: warning message"
|
||||
}
|
||||
return "sqlite3: unknown error"
|
||||
}
|
||||
128
internal/util/func.go
Normal file
128
internal/util/func.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
type i32 interface{ ~int32 | ~uint32 }
|
||||
type i64 interface{ ~int64 | ~uint64 }
|
||||
|
||||
type funcVI[T0 i32] func(context.Context, api.Module, T0)
|
||||
|
||||
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]))
|
||||
}
|
||||
|
||||
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVI[T0](fn),
|
||||
[]api.ValueType{api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
|
||||
|
||||
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
|
||||
}
|
||||
|
||||
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
|
||||
|
||||
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
|
||||
}
|
||||
|
||||
func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcII[TR, T0](fn),
|
||||
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
|
||||
|
||||
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
|
||||
}
|
||||
|
||||
func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcIII[TR, T0, T1](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
|
||||
|
||||
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
|
||||
}
|
||||
|
||||
func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcIIII[TR, T0, T1, T2](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
|
||||
|
||||
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
|
||||
}
|
||||
|
||||
func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcIIIII[TR, T0, T1, T2, T3](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
|
||||
|
||||
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
|
||||
}
|
||||
|
||||
func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcIIIIII[TR, T0, T1, T2, T3, T4](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
|
||||
|
||||
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
|
||||
}
|
||||
|
||||
func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcIIIIJ[TR, T0, T1, T2, T3](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
|
||||
|
||||
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
|
||||
}
|
||||
|
||||
func ExportFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcIIJ[TR, T0, T1](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
75
internal/util/handle.go
Normal file
75
internal/util/handle.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
)
|
||||
|
||||
type handleKey struct{}
|
||||
type handleState struct {
|
||||
handles []any
|
||||
empty int
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context) context.Context {
|
||||
state := new(handleState)
|
||||
ctx = experimental.WithCloseNotifier(ctx, state)
|
||||
ctx = context.WithValue(ctx, handleKey{}, state)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
|
||||
for _, h := range s.handles {
|
||||
if c, ok := h.(io.Closer); ok {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
s.handles = nil
|
||||
s.empty = 0
|
||||
}
|
||||
|
||||
func GetHandle(ctx context.Context, id uint32) any {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
return s.handles[^id]
|
||||
}
|
||||
|
||||
func DelHandle(ctx context.Context, id uint32) error {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
a := s.handles[^id]
|
||||
s.handles[^id] = nil
|
||||
s.empty++
|
||||
if c, ok := a.(io.Closer); ok {
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddHandle(ctx context.Context, a any) (id uint32) {
|
||||
if a == nil {
|
||||
panic(NilErr)
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
|
||||
// Find an empty slot.
|
||||
if s.empty > cap(s.handles)-len(s.handles) {
|
||||
for id, h := range s.handles {
|
||||
if h == nil {
|
||||
s.empty--
|
||||
s.handles[id] = a
|
||||
return ^uint32(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new slot.
|
||||
s.handles = append(s.handles, a)
|
||||
return -uint32(len(s.handles))
|
||||
}
|
||||
110
internal/util/mem.go
Normal file
110
internal/util/mem.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func View(mod api.Module, ptr uint32, size uint64) []byte {
|
||||
if ptr == 0 {
|
||||
panic(NilErr)
|
||||
}
|
||||
if size > math.MaxUint32 {
|
||||
panic(RangeErr)
|
||||
}
|
||||
buf, ok := mod.Memory().Read(ptr, uint32(size))
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func ReadUint32(mod api.Module, ptr uint32) uint32 {
|
||||
if ptr == 0 {
|
||||
panic(NilErr)
|
||||
}
|
||||
v, ok := mod.Memory().ReadUint32Le(ptr)
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func WriteUint32(mod api.Module, ptr uint32, v uint32) {
|
||||
if ptr == 0 {
|
||||
panic(NilErr)
|
||||
}
|
||||
ok := mod.Memory().WriteUint32Le(ptr, v)
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadUint64(mod api.Module, ptr uint32) uint64 {
|
||||
if ptr == 0 {
|
||||
panic(NilErr)
|
||||
}
|
||||
v, ok := mod.Memory().ReadUint64Le(ptr)
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func WriteUint64(mod api.Module, ptr uint32, v uint64) {
|
||||
if ptr == 0 {
|
||||
panic(NilErr)
|
||||
}
|
||||
ok := mod.Memory().WriteUint64Le(ptr, v)
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadFloat64(mod api.Module, ptr uint32) float64 {
|
||||
return math.Float64frombits(ReadUint64(mod, ptr))
|
||||
}
|
||||
|
||||
func WriteFloat64(mod api.Module, ptr uint32, v float64) {
|
||||
WriteUint64(mod, ptr, math.Float64bits(v))
|
||||
}
|
||||
|
||||
func ReadString(mod api.Module, ptr, maxlen uint32) string {
|
||||
if ptr == 0 {
|
||||
panic(NilErr)
|
||||
}
|
||||
switch maxlen {
|
||||
case 0:
|
||||
return ""
|
||||
case math.MaxUint32:
|
||||
// avoid overflow
|
||||
default:
|
||||
maxlen = maxlen + 1
|
||||
}
|
||||
mem := mod.Memory()
|
||||
buf, ok := mem.Read(ptr, maxlen)
|
||||
if !ok {
|
||||
buf, ok = mem.Read(ptr, mem.Size()-ptr)
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
}
|
||||
}
|
||||
if i := bytes.IndexByte(buf, 0); i < 0 {
|
||||
panic(NoNulErr)
|
||||
} else {
|
||||
return string(buf[:i])
|
||||
}
|
||||
}
|
||||
|
||||
func WriteBytes(mod api.Module, ptr uint32, b []byte) {
|
||||
buf := View(mod, ptr, uint64(len(b)))
|
||||
copy(buf, b)
|
||||
}
|
||||
|
||||
func WriteString(mod api.Module, ptr uint32, s string) {
|
||||
buf := View(mod, ptr, uint64(len(s)+1))
|
||||
buf[len(s)] = 0
|
||||
copy(buf, s)
|
||||
}
|
||||
92
internal/util/mem_test.go
Normal file
92
internal/util/mem_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/tetratelabs/wazero/experimental/wazerotest"
|
||||
)
|
||||
|
||||
func TestView_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
View(mock, 0, 8)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestView_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
View(mock, wazerotest.PageSize-2, 8)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestView_overflow(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
View(mock, 1, math.MaxInt64)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint32_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint32(mock, 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint32_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint32(mock, wazerotest.PageSize-2)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint64_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint64(mock, 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint64_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint64(mock, wazerotest.PageSize-2)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint32_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint32(mock, 0, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint32_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint32(mock, wazerotest.PageSize-2, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint64_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint64(mock, 0, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint64_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint64(mock, wazerotest.PageSize-2, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadString_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadString(mock, wazerotest.PageSize+2, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}
|
||||
111
mem.go
111
mem.go
@@ -1,111 +0,0 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
type memory struct {
|
||||
mod api.Module
|
||||
}
|
||||
|
||||
func (m memory) view(ptr, size uint32) []byte {
|
||||
if ptr == 0 {
|
||||
panic(nilErr)
|
||||
}
|
||||
buf, ok := m.mod.Memory().Read(ptr, size)
|
||||
if !ok {
|
||||
panic(rangeErr)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func (m memory) readUint32(ptr uint32) uint32 {
|
||||
if ptr == 0 {
|
||||
panic(nilErr)
|
||||
}
|
||||
v, ok := m.mod.Memory().ReadUint32Le(ptr)
|
||||
if !ok {
|
||||
panic(rangeErr)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (m memory) writeUint32(ptr, v uint32) {
|
||||
if ptr == 0 {
|
||||
panic(nilErr)
|
||||
}
|
||||
ok := m.mod.Memory().WriteUint32Le(ptr, v)
|
||||
if !ok {
|
||||
panic(rangeErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (m memory) readUint64(ptr uint32) uint64 {
|
||||
if ptr == 0 {
|
||||
panic(nilErr)
|
||||
}
|
||||
v, ok := m.mod.Memory().ReadUint64Le(ptr)
|
||||
if !ok {
|
||||
panic(rangeErr)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (m memory) writeUint64(ptr uint32, v uint64) {
|
||||
if ptr == 0 {
|
||||
panic(nilErr)
|
||||
}
|
||||
ok := m.mod.Memory().WriteUint64Le(ptr, v)
|
||||
if !ok {
|
||||
panic(rangeErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (m memory) readFloat64(ptr uint32) float64 {
|
||||
return math.Float64frombits(m.readUint64(ptr))
|
||||
}
|
||||
|
||||
func (m memory) writeFloat64(ptr uint32, v float64) {
|
||||
m.writeUint64(ptr, math.Float64bits(v))
|
||||
}
|
||||
|
||||
func (m memory) readString(ptr, maxlen uint32) string {
|
||||
if ptr == 0 {
|
||||
panic(nilErr)
|
||||
}
|
||||
switch maxlen {
|
||||
case 0:
|
||||
return ""
|
||||
case math.MaxUint32:
|
||||
// avoid overflow
|
||||
default:
|
||||
maxlen = maxlen + 1
|
||||
}
|
||||
mem := m.mod.Memory()
|
||||
buf, ok := mem.Read(ptr, maxlen)
|
||||
if !ok {
|
||||
buf, ok = mem.Read(ptr, mem.Size()-ptr)
|
||||
if !ok {
|
||||
panic(rangeErr)
|
||||
}
|
||||
}
|
||||
if i := bytes.IndexByte(buf, 0); i < 0 {
|
||||
panic(noNulErr)
|
||||
} else {
|
||||
return string(buf[:i])
|
||||
}
|
||||
}
|
||||
|
||||
func (m memory) writeBytes(ptr uint32, b []byte) {
|
||||
buf := m.view(ptr, uint32(len(b)))
|
||||
copy(buf, b)
|
||||
}
|
||||
|
||||
func (m memory) writeString(ptr uint32, s string) {
|
||||
buf := m.view(ptr, uint32(len(s)+1))
|
||||
buf[len(s)] = 0
|
||||
copy(buf, s)
|
||||
}
|
||||
83
mem_test.go
83
mem_test.go
@@ -1,83 +0,0 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_memory_view_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.view(0, 8)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_view_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.view(126, 8)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_readUint32_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.readUint32(0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_readUint32_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.readUint32(126)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_readUint64_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.readUint64(0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_readUint64_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.readUint64(126)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_writeUint32_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.writeUint32(0, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_writeUint32_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.writeUint32(126, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_writeUint64_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.writeUint64(0, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_writeUint64_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.writeUint64(126, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_memory_readString_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mem := newMemory(128)
|
||||
mem.readString(130, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}
|
||||
394
sqlite.go
Normal file
394
sqlite.go
Normal file
@@ -0,0 +1,394 @@
|
||||
// Package sqlite3 wraps the C SQLite API.
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// Configure SQLite WASM.
|
||||
//
|
||||
// Importing package embed initializes these
|
||||
// with an appropriate build of SQLite:
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/embed"
|
||||
var (
|
||||
Binary []byte // WASM binary to load.
|
||||
Path string // Path to load the binary from.
|
||||
)
|
||||
|
||||
var instance struct {
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func compileSQLite() {
|
||||
ctx := context.Background()
|
||||
instance.runtime = wazero.NewRuntime(ctx)
|
||||
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportHostFunctions(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
bin, instance.err = os.ReadFile(Path)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if bin == nil {
|
||||
instance.err = util.BinaryErr
|
||||
return
|
||||
}
|
||||
|
||||
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
type sqlite struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
api sqliteAPI
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
type sqliteKey struct{}
|
||||
|
||||
func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
return nil, instance.err
|
||||
}
|
||||
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx = util.NewContext(context.Background())
|
||||
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
|
||||
|
||||
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
|
||||
instance.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getFun := func(name string) api.Function {
|
||||
f := sqlt.mod.ExportedFunction(name)
|
||||
if f == nil {
|
||||
err = util.NoFuncErr + util.ErrorString(name)
|
||||
return nil
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
getVal := func(name string) uint32 {
|
||||
g := sqlt.mod.ExportedGlobal(name)
|
||||
if g == nil {
|
||||
err = util.NoGlobalErr + util.ErrorString(name)
|
||||
return 0
|
||||
}
|
||||
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
|
||||
}
|
||||
|
||||
sqlt.api = sqliteAPI{
|
||||
free: getFun("free"),
|
||||
malloc: getFun("malloc"),
|
||||
destructor: getVal("malloc_destructor"),
|
||||
errcode: getFun("sqlite3_errcode"),
|
||||
errstr: getFun("sqlite3_errstr"),
|
||||
errmsg: getFun("sqlite3_errmsg"),
|
||||
erroff: getFun("sqlite3_error_offset"),
|
||||
open: getFun("sqlite3_open_v2"),
|
||||
close: getFun("sqlite3_close"),
|
||||
closeZombie: getFun("sqlite3_close_v2"),
|
||||
prepare: getFun("sqlite3_prepare_v3"),
|
||||
finalize: getFun("sqlite3_finalize"),
|
||||
reset: getFun("sqlite3_reset"),
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
bindName: getFun("sqlite3_bind_parameter_name"),
|
||||
bindNull: getFun("sqlite3_bind_null"),
|
||||
bindInteger: getFun("sqlite3_bind_int64"),
|
||||
bindFloat: getFun("sqlite3_bind_double"),
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
columnInteger: getFun("sqlite3_column_int64"),
|
||||
columnFloat: getFun("sqlite3_column_double"),
|
||||
columnText: getFun("sqlite3_column_text"),
|
||||
columnBlob: getFun("sqlite3_column_blob"),
|
||||
columnBytes: getFun("sqlite3_column_bytes"),
|
||||
blobOpen: getFun("sqlite3_blob_open"),
|
||||
blobClose: getFun("sqlite3_blob_close"),
|
||||
blobReopen: getFun("sqlite3_blob_reopen"),
|
||||
blobBytes: getFun("sqlite3_blob_bytes"),
|
||||
blobRead: getFun("sqlite3_blob_read"),
|
||||
blobWrite: getFun("sqlite3_blob_write"),
|
||||
backupInit: getFun("sqlite3_backup_init"),
|
||||
backupStep: getFun("sqlite3_backup_step"),
|
||||
backupFinish: getFun("sqlite3_backup_finish"),
|
||||
backupRemaining: getFun("sqlite3_backup_remaining"),
|
||||
backupPageCount: getFun("sqlite3_backup_pagecount"),
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
anyCollation: getFun("sqlite3_anycollseq_init"),
|
||||
createCollation: getFun("sqlite3_create_collation_go"),
|
||||
createFunction: getFun("sqlite3_create_function_go"),
|
||||
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
|
||||
createWindow: getFun("sqlite3_create_window_function_go"),
|
||||
aggregateCtx: getFun("sqlite3_aggregate_context"),
|
||||
userData: getFun("sqlite3_user_data"),
|
||||
setAuxData: getFun("sqlite3_set_auxdata_go"),
|
||||
getAuxData: getFun("sqlite3_get_auxdata"),
|
||||
valueType: getFun("sqlite3_value_type"),
|
||||
valueInteger: getFun("sqlite3_value_int64"),
|
||||
valueFloat: getFun("sqlite3_value_double"),
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
||||
resultErrorBig: getFun("sqlite3_result_error_toobig"),
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sqlt, nil
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) close() error {
|
||||
return sqlt.mod.Close(sqlt.ctx)
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := Error{code: rc}
|
||||
|
||||
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
|
||||
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
}
|
||||
}
|
||||
|
||||
switch err.msg {
|
||||
case err.str, "not an error":
|
||||
err.msg = ""
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(sqlt.stack[:], params)
|
||||
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return sqlt.stack[0]
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) free(ptr uint32) {
|
||||
if ptr == 0 {
|
||||
return
|
||||
}
|
||||
sqlt.call(sqlt.api.free, uint64(ptr))
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) new(size uint64) uint32 {
|
||||
if size > _MAX_ALLOCATION_SIZE {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := sqlt.new(uint64(len(b)))
|
||||
util.WriteBytes(sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newString(s string) uint32 {
|
||||
ptr := sqlt.new(uint64(len(s) + 1))
|
||||
util.WriteString(sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newArena(size uint64) arena {
|
||||
return arena{
|
||||
sqlt: sqlt,
|
||||
size: uint32(size),
|
||||
base: sqlt.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
sqlt *sqlite
|
||||
ptrs []uint32
|
||||
base uint32
|
||||
next uint32
|
||||
size uint32
|
||||
}
|
||||
|
||||
func (a *arena) free() {
|
||||
if a.sqlt == nil {
|
||||
return
|
||||
}
|
||||
a.reset()
|
||||
a.sqlt.free(a.base)
|
||||
a.sqlt = nil
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.sqlt.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
}
|
||||
|
||||
func (a *arena) new(size uint64) uint32 {
|
||||
if size <= uint64(a.size-a.next) {
|
||||
ptr := a.base + a.next
|
||||
a.next += uint32(size)
|
||||
return ptr
|
||||
}
|
||||
ptr := a.sqlt.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) bytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := a.new(uint64(len(b)))
|
||||
util.WriteBytes(a.sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint64(len(s) + 1))
|
||||
util.WriteString(a.sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
type sqliteAPI struct {
|
||||
free api.Function
|
||||
malloc api.Function
|
||||
errcode api.Function
|
||||
errstr api.Function
|
||||
errmsg api.Function
|
||||
erroff api.Function
|
||||
open api.Function
|
||||
close api.Function
|
||||
closeZombie api.Function
|
||||
prepare api.Function
|
||||
finalize api.Function
|
||||
reset api.Function
|
||||
step api.Function
|
||||
exec api.Function
|
||||
clearBindings api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindNull api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
columnInteger api.Function
|
||||
columnFloat api.Function
|
||||
columnText api.Function
|
||||
columnBlob api.Function
|
||||
columnBytes api.Function
|
||||
blobOpen api.Function
|
||||
blobClose api.Function
|
||||
blobReopen api.Function
|
||||
blobBytes api.Function
|
||||
blobRead api.Function
|
||||
blobWrite api.Function
|
||||
backupInit api.Function
|
||||
backupStep api.Function
|
||||
backupFinish api.Function
|
||||
backupRemaining api.Function
|
||||
backupPageCount api.Function
|
||||
changes api.Function
|
||||
lastRowid api.Function
|
||||
autocommit api.Function
|
||||
anyCollation api.Function
|
||||
createCollation api.Function
|
||||
createFunction api.Function
|
||||
createAggregate api.Function
|
||||
createWindow api.Function
|
||||
aggregateCtx api.Function
|
||||
userData api.Function
|
||||
setAuxData api.Function
|
||||
getAuxData api.Function
|
||||
valueType api.Function
|
||||
valueInteger api.Function
|
||||
valueFloat api.Function
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
resultErrorMem api.Function
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
}
|
||||
4
sqlite3/.gitignore
vendored
Normal file
4
sqlite3/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
ext/
|
||||
sqlite3.c
|
||||
sqlite3.h
|
||||
sqlite3ext.h
|
||||
@@ -1,13 +1,35 @@
|
||||
#!/usr/bin/env bash
|
||||
set -eo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
if [ ! -f "sqlite3.c" ]; then
|
||||
url="https://www.sqlite.org/2022/sqlite-amalgamation-3400100.zip"
|
||||
curl "$url" > sqlite.zip
|
||||
unzip -d . sqlite.zip
|
||||
mv sqlite-amalgamation-*/sqlite3* .
|
||||
rm -rf sqlite-amalgamation-*
|
||||
rm sqlite.zip
|
||||
fi
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3430200.zip"
|
||||
unzip -d . sqlite-amalgamation-*.zip
|
||||
mv sqlite-amalgamation-*/sqlite3* .
|
||||
rm -rf sqlite-amalgamation-*
|
||||
|
||||
cat *.patch | patch --posix
|
||||
|
||||
mkdir -p ext/
|
||||
cd ext/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/anycollseq.c"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/mptest/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/multiwrite01.test"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/speedtest1/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/test/speedtest1.c"
|
||||
cd ~-
|
||||
7
sqlite3/format.sh
Executable file
7
sqlite3/format.sh
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
shopt -s extglob
|
||||
clang-format -i !(sqlite3*).@(c|h)
|
||||
40
sqlite3/func.c
Normal file
40
sqlite3/func.c
Normal file
@@ -0,0 +1,40 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_compare(void *, int, const void *, int, const void *);
|
||||
void go_func(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_step(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_final(sqlite3_context *);
|
||||
void go_value(sqlite3_context *);
|
||||
void go_inverse(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_destroy(void *);
|
||||
|
||||
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
|
||||
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
|
||||
go_func, NULL, NULL, go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
|
||||
int nArg, int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, NULL, NULL,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, go_value,
|
||||
go_inverse, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
|
||||
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
|
||||
}
|
||||
14
sqlite3/locking_mode.patch
Normal file
14
sqlite3/locking_mode.patch
Normal file
@@ -0,0 +1,14 @@
|
||||
# Use exclusive locking mode for WAL databases with v1 VFSes.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -63210,7 +63210,9 @@
|
||||
SQLITE_PRIVATE int sqlite3PagerWalSupported(Pager *pPager){
|
||||
const sqlite3_io_methods *pMethods = pPager->fd->pMethods;
|
||||
if( pPager->noLock ) return 0;
|
||||
- return pPager->exclusiveMode || (pMethods->iVersion>=2 && pMethods->xShmMap);
|
||||
+ if( pMethods->iVersion>=2 && pMethods->xShmMap ) return 1;
|
||||
+ pPager->exclusiveMode = 1;
|
||||
+ return 1;
|
||||
}
|
||||
|
||||
/*
|
||||
122
sqlite3/main.c
122
sqlite3/main.c
@@ -1,102 +1,28 @@
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <time.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
// VFS
|
||||
#include "vfs.c"
|
||||
// Extensions
|
||||
#include "ext/anycollseq.c"
|
||||
#include "ext/base64.c"
|
||||
#include "ext/decimal.c"
|
||||
#include "ext/regexp.c"
|
||||
#include "ext/series.c"
|
||||
#include "ext/uint.c"
|
||||
#include "ext/uuid.c"
|
||||
#include "func.c"
|
||||
#include "time.c"
|
||||
|
||||
int main() {
|
||||
int rc = sqlite3_initialize();
|
||||
if (rc != SQLITE_OK) return 1;
|
||||
__attribute__((constructor)) void init() {
|
||||
sqlite3_initialize();
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_decimal_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
|
||||
}
|
||||
|
||||
int go_localtime(sqlite3_int64, struct tm *);
|
||||
|
||||
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
|
||||
int go_sleep(sqlite3_vfs *, int microseconds);
|
||||
int go_current_time(sqlite3_vfs *, double *);
|
||||
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
|
||||
|
||||
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
|
||||
int *pOutFlags);
|
||||
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
|
||||
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
|
||||
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
|
||||
|
||||
struct go_file {
|
||||
sqlite3_file base;
|
||||
int id;
|
||||
int eLock;
|
||||
};
|
||||
|
||||
int go_close(sqlite3_file *);
|
||||
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_truncate(sqlite3_file *, sqlite3_int64 size);
|
||||
int go_sync(sqlite3_file *, int flags);
|
||||
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
|
||||
int go_file_control(sqlite3_file *pFile, int op, void *pArg);
|
||||
|
||||
int go_lock(sqlite3_file *pFile, int eLock);
|
||||
int go_unlock(sqlite3_file *pFile, int eLock);
|
||||
int go_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
|
||||
|
||||
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
|
||||
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
|
||||
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
|
||||
*pResOut = 0;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
|
||||
return SQLITE_NOTFOUND;
|
||||
}
|
||||
static int no_sector_size(sqlite3_file *pFile) { return 0; }
|
||||
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
|
||||
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
|
||||
return go_localtime((sqlite3_int64)*pTime, pTm);
|
||||
}
|
||||
|
||||
static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName,
|
||||
sqlite3_file *file, int flags, int *pOutFlags) {
|
||||
static const sqlite3_io_methods go_io = {
|
||||
.iVersion = 1,
|
||||
.xClose = go_close,
|
||||
.xRead = go_read,
|
||||
.xWrite = go_write,
|
||||
.xTruncate = go_truncate,
|
||||
.xSync = go_sync,
|
||||
.xFileSize = go_file_size,
|
||||
.xLock = go_lock,
|
||||
.xUnlock = go_unlock,
|
||||
.xCheckReservedLock = go_check_reserved_lock,
|
||||
.xFileControl = no_file_control,
|
||||
.xSectorSize = no_sector_size,
|
||||
.xDeviceCharacteristics = no_device_characteristics,
|
||||
};
|
||||
int rc = go_open(vfs, zName, file, flags, pOutFlags);
|
||||
file->pMethods = (char)rc == SQLITE_OK ? &go_io : NULL;
|
||||
return rc;
|
||||
}
|
||||
|
||||
int sqlite3_os_init() {
|
||||
static sqlite3_vfs go_vfs = {
|
||||
.iVersion = 2,
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = "go",
|
||||
|
||||
.xOpen = go_open_c,
|
||||
.xDelete = go_delete,
|
||||
.xAccess = go_access,
|
||||
.xFullPathname = go_full_pathname,
|
||||
|
||||
.xRandomness = go_randomness,
|
||||
.xSleep = go_sleep,
|
||||
.xCurrentTime = go_current_time,
|
||||
.xCurrentTimeInt64 = go_current_time_64,
|
||||
};
|
||||
return sqlite3_vfs_register(&go_vfs, /*default=*/true);
|
||||
}
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
@@ -28,18 +28,35 @@
|
||||
#define SQLITE_OMIT_AUTOINIT
|
||||
#define SQLITE_USE_ALLOCA
|
||||
|
||||
// Recommended Extensions
|
||||
// Other Options
|
||||
|
||||
// #define SQLITE_ENABLE_MATH_FUNCTIONS 1
|
||||
// #define SQLITE_ENABLE_FTS3 1
|
||||
// #define SQLITE_ENABLE_FTS3_PARENTHESIS 1
|
||||
// #define SQLITE_ENABLE_FTS4 1
|
||||
// #define SQLITE_ENABLE_FTS5 1
|
||||
// #define SQLITE_ENABLE_RTREE 1
|
||||
// #define SQLITE_ENABLE_GEOPOLY 1
|
||||
#define SQLITE_ALLOW_URI_AUTHORITY
|
||||
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
|
||||
#define SQLITE_ENABLE_ATOMIC_WRITE
|
||||
#define SQLITE_OMIT_DESERIALIZE
|
||||
|
||||
// Need this to access WAL databases without the use of shared memory.
|
||||
#define SQLITE_DEFAULT_LOCKING_MODE 1
|
||||
// Because WASM does not support shared memory,
|
||||
// SQLite disables WAL for WASM builds.
|
||||
// We patch SQLite to use exclusive locking mode instead.
|
||||
// https://www.sqlite.org/wal.html#noshm
|
||||
#undef SQLITE_OMIT_WAL
|
||||
|
||||
// Implemented in Go.
|
||||
// Amalgamated Extensions
|
||||
|
||||
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
|
||||
#define SQLITE_ENABLE_JSON1 1
|
||||
#define SQLITE_ENABLE_FTS3 1
|
||||
#define SQLITE_ENABLE_FTS3_PARENTHESIS 1
|
||||
#define SQLITE_ENABLE_FTS4 1
|
||||
#define SQLITE_ENABLE_FTS5 1
|
||||
#define SQLITE_ENABLE_RTREE 1
|
||||
#define SQLITE_ENABLE_GEOPOLY 1
|
||||
|
||||
// Session Extension
|
||||
// #define SQLITE_ENABLE_SESSION
|
||||
// #define SQLITE_ENABLE_PREUPDATE_HOOK
|
||||
|
||||
#define SQLITE_SOUNDEX
|
||||
|
||||
// Implemented in vfs.c.
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime);
|
||||
32
sqlite3/time.c
Normal file
32
sqlite3/time.c
Normal file
@@ -0,0 +1,32 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
|
||||
const void *pKey2) {
|
||||
// Remove a Z suffix if one key is no longer than the other.
|
||||
// A Z suffix collates before any character but after the empty string.
|
||||
// This avoids making different keys equal.
|
||||
const int nK1 = nKey1;
|
||||
const int nK2 = nKey2;
|
||||
const char *pK1 = (const char *)pKey1;
|
||||
const char *pK2 = (const char *)pKey2;
|
||||
if (nK1 && nK1 <= nK2 && pK1[nK1 - 1] == 'Z') {
|
||||
nKey1--;
|
||||
}
|
||||
if (nK2 && nK2 <= nK1 && pK2[nK2 - 1] == 'Z') {
|
||||
nKey2--;
|
||||
}
|
||||
|
||||
int n = nKey1 < nKey2 ? nKey1 : nKey2;
|
||||
int rc = memcmp(pKey1, pKey2, n);
|
||||
if (rc == 0) {
|
||||
rc = nKey1 - nKey2;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
|
||||
const sqlite3_api_routines *pApi) {
|
||||
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
|
||||
}
|
||||
137
sqlite3/vfs.c
Normal file
137
sqlite3/vfs.c
Normal file
@@ -0,0 +1,137 @@
|
||||
#include <time.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_localtime(struct tm *, sqlite3_int64);
|
||||
int go_vfs_find(const char *zVfsName);
|
||||
|
||||
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
|
||||
int go_sleep(sqlite3_vfs *, int microseconds);
|
||||
int go_current_time(sqlite3_vfs *, double *);
|
||||
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
|
||||
|
||||
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
|
||||
int *pOutFlags);
|
||||
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
|
||||
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
|
||||
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
|
||||
|
||||
int go_close(sqlite3_file *);
|
||||
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_truncate(sqlite3_file *, sqlite3_int64 size);
|
||||
int go_sync(sqlite3_file *, int flags);
|
||||
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
|
||||
int go_file_control(sqlite3_file *, int op, void *pArg);
|
||||
int go_sector_size(sqlite3_file *file);
|
||||
int go_device_characteristics(sqlite3_file *file);
|
||||
|
||||
int go_lock(sqlite3_file *, int eLock);
|
||||
int go_unlock(sqlite3_file *, int eLock);
|
||||
int go_check_reserved_lock(sqlite3_file *, int *pResOut);
|
||||
|
||||
static int go_open_wrapper(sqlite3_vfs *vfs, sqlite3_filename zName,
|
||||
sqlite3_file *file, int flags, int *pOutFlags) {
|
||||
static const sqlite3_io_methods os_io = {
|
||||
.iVersion = 1,
|
||||
.xClose = go_close,
|
||||
.xRead = go_read,
|
||||
.xWrite = go_write,
|
||||
.xTruncate = go_truncate,
|
||||
.xSync = go_sync,
|
||||
.xFileSize = go_file_size,
|
||||
.xLock = go_lock,
|
||||
.xUnlock = go_unlock,
|
||||
.xCheckReservedLock = go_check_reserved_lock,
|
||||
.xFileControl = go_file_control,
|
||||
.xSectorSize = go_sector_size,
|
||||
.xDeviceCharacteristics = go_device_characteristics,
|
||||
};
|
||||
memset(file, 0, vfs->szOsFile);
|
||||
int rc = go_open(vfs, zName, file, flags, pOutFlags);
|
||||
if (rc) {
|
||||
return rc;
|
||||
}
|
||||
file->pMethods = &os_io;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
struct go_file {
|
||||
sqlite3_file base;
|
||||
int handle;
|
||||
};
|
||||
|
||||
int sqlite3_os_init() {
|
||||
static sqlite3_vfs os_vfs = {
|
||||
.iVersion = 2,
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = "os",
|
||||
|
||||
.xOpen = go_open_wrapper,
|
||||
.xDelete = go_delete,
|
||||
.xAccess = go_access,
|
||||
.xFullPathname = go_full_pathname,
|
||||
|
||||
.xRandomness = go_randomness,
|
||||
.xSleep = go_sleep,
|
||||
.xCurrentTime = go_current_time,
|
||||
.xCurrentTimeInt64 = go_current_time_64,
|
||||
};
|
||||
return sqlite3_vfs_register(&os_vfs, /*default=*/true);
|
||||
}
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
|
||||
return go_localtime(pTm, (sqlite3_int64)*pTime);
|
||||
}
|
||||
|
||||
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
if (zVfsName) {
|
||||
static sqlite3_vfs *go_vfs_list;
|
||||
sqlite3_vfs *found = NULL;
|
||||
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
|
||||
sqlite3_vfs *it = *next;
|
||||
if (go_vfs_find(it->zName)) {
|
||||
if (!strcmp(zVfsName, it->zName)) found = it;
|
||||
next = &it->pNext;
|
||||
} else {
|
||||
*next = it->pNext;
|
||||
free(it);
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return found;
|
||||
}
|
||||
if (go_vfs_find(zVfsName)) {
|
||||
sqlite3_vfs *prev = go_vfs_list;
|
||||
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
|
||||
char *name = (char *)(go_vfs_list + 1);
|
||||
strcpy(name, zVfsName);
|
||||
*go_vfs_list = (sqlite3_vfs){
|
||||
.iVersion = 2,
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = name,
|
||||
.pNext = prev,
|
||||
|
||||
.xOpen = go_open_wrapper,
|
||||
.xDelete = go_delete,
|
||||
.xAccess = go_access,
|
||||
.xFullPathname = go_full_pathname,
|
||||
|
||||
.xRandomness = go_randomness,
|
||||
.xSleep = go_sleep,
|
||||
.xCurrentTime = go_current_time,
|
||||
.xCurrentTimeInt64 = go_current_time_64,
|
||||
};
|
||||
return go_vfs_list;
|
||||
}
|
||||
}
|
||||
return sqlite3_vfs_find_orig(zVfsName);
|
||||
}
|
||||
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
|
||||
12
sqlite3/vfs_find.patch
Normal file
12
sqlite3/vfs_find.patch
Normal file
@@ -0,0 +1,12 @@
|
||||
# Wrap sqlite3_vfs_find.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -25394,7 +25394,7 @@
|
||||
** Locate a VFS by name. If no name is given, simply return the
|
||||
** first VFS on the list.
|
||||
*/
|
||||
-SQLITE_API sqlite3_vfs *sqlite3_vfs_find(const char *zVfs){
|
||||
+SQLITE_API sqlite3_vfs *sqlite3_vfs_find_orig(const char *zVfs){
|
||||
sqlite3_vfs *pVfs = 0;
|
||||
#if SQLITE_THREADSAFE
|
||||
sqlite3_mutex *mutex;
|
||||
221
sqlite_test.go
Normal file
221
sqlite_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
|
||||
func Test_sqlite_error_OOM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
sqlite.error(uint64(NOMEM), 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_sqlite_call_closed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
sqlite.call(sqlite.api.free)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func Test_sqlite_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
t.Run("MaxUint32", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
sqlite.new(math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
})
|
||||
|
||||
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
t.Error("want panic")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_sqlite_newArena(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
arena := sqlite.newArena(16)
|
||||
defer arena.free()
|
||||
|
||||
const title = "Lorem ipsum"
|
||||
ptr := arena.string(title)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
const body = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
|
||||
ptr = arena.string(body)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
|
||||
t.Errorf("got %q, want %q", got, body)
|
||||
}
|
||||
|
||||
ptr = arena.bytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("want nullptr")
|
||||
}
|
||||
ptr = arena.bytes([]byte(title))
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
arena.free()
|
||||
}
|
||||
|
||||
func Test_sqlite_newBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := sqlite.newBytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("got %#x, want nullptr", ptr)
|
||||
}
|
||||
|
||||
buf := []byte("sqlite3")
|
||||
ptr = sqlite.newBytes(buf)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := buf
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sqlite_newString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3\000sqlite3"
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := str + "\000"
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sqlite_getString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3" + "\000 drop this"
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := "sqlite3"
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
|
||||
t.Error("want panic")
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(sqlite.mod, 0, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}()
|
||||
}
|
||||
|
||||
func Test_sqlite_free(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sqlite.close()
|
||||
|
||||
sqlite.free(0)
|
||||
|
||||
ptr := sqlite.new(1)
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
sqlite.free(ptr)
|
||||
}
|
||||
317
stmt.go
317
stmt.go
@@ -2,6 +2,9 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Stmt is a prepared statement object.
|
||||
@@ -9,48 +12,41 @@ import (
|
||||
// https://www.sqlite.org/c3ref/stmt.html
|
||||
type Stmt struct {
|
||||
c *Conn
|
||||
handle uint32
|
||||
err error
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Close destroys the prepared statement object.
|
||||
//
|
||||
// It is safe to close a nil, zero or closed Stmt.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/finalize.html
|
||||
func (s *Stmt) Close() error {
|
||||
if s == nil || s.handle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
r := s.c.call(s.c.api.finalize, uint64(s.handle))
|
||||
|
||||
s.handle = 0
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// Reset resets the prepared statement object.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/reset.html
|
||||
func (s *Stmt) Reset() error {
|
||||
r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
r := s.c.call(s.c.api.reset, uint64(s.handle))
|
||||
s.err = nil
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// ClearBindings resets all bindings on the prepared statement.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/clear_bindings.html
|
||||
func (s *Stmt) ClearBindings() error {
|
||||
r, err := s.c.api.clearBindings.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// Step evaluates the SQL statement.
|
||||
@@ -63,17 +59,15 @@ func (s *Stmt) ClearBindings() error {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/step.html
|
||||
func (s *Stmt) Step() bool {
|
||||
r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if r[0] == _ROW {
|
||||
s.c.checkInterrupt()
|
||||
r := s.c.call(s.c.api.step, uint64(s.handle))
|
||||
switch r {
|
||||
case _ROW:
|
||||
return true
|
||||
}
|
||||
if r[0] == _DONE {
|
||||
case _DONE:
|
||||
s.err = nil
|
||||
} else {
|
||||
s.err = s.c.error(r[0])
|
||||
default:
|
||||
s.err = s.c.error(r)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -98,12 +92,9 @@ func (s *Stmt) Exec() error {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_parameter_count.html
|
||||
func (s *Stmt) BindCount() int {
|
||||
r, err := s.c.api.bindCount.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindCount,
|
||||
uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// BindIndex returns the index of a parameter in the prepared statement
|
||||
@@ -113,12 +104,9 @@ func (s *Stmt) BindCount() int {
|
||||
func (s *Stmt) BindIndex(name string) int {
|
||||
defer s.c.arena.reset()
|
||||
namePtr := s.c.arena.string(name)
|
||||
r, err := s.c.api.bindIndex.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindIndex,
|
||||
uint64(s.handle), uint64(namePtr))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// BindName returns the name of a parameter in the prepared statement.
|
||||
@@ -126,17 +114,14 @@ func (s *Stmt) BindIndex(name string) int {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_parameter_name.html
|
||||
func (s *Stmt) BindName(param int) string {
|
||||
r, err := s.c.api.bindName.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindName,
|
||||
uint64(s.handle), uint64(param))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
ptr := uint32(r)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
return s.c.mem.readString(ptr, _MAX_STRING)
|
||||
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
|
||||
}
|
||||
|
||||
// BindBool binds a bool to the prepared statement.
|
||||
@@ -146,10 +131,11 @@ func (s *Stmt) BindName(param int) string {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindBool(param int, value bool) error {
|
||||
var i int64
|
||||
if value {
|
||||
return s.BindInt64(param, 1)
|
||||
i = 1
|
||||
}
|
||||
return s.BindInt64(param, 0)
|
||||
return s.BindInt64(param, i)
|
||||
}
|
||||
|
||||
// BindInt binds an int to the prepared statement.
|
||||
@@ -165,12 +151,9 @@ func (s *Stmt) BindInt(param int, value int) error {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindInt64(param int, value int64) error {
|
||||
r, err := s.c.api.bindInteger.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindInteger,
|
||||
uint64(s.handle), uint64(param), uint64(value))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindFloat binds a float64 to the prepared statement.
|
||||
@@ -178,12 +161,9 @@ func (s *Stmt) BindInt64(param int, value int64) error {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindFloat(param int, value float64) error {
|
||||
r, err := s.c.api.bindFloat.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindFloat,
|
||||
uint64(s.handle), uint64(param), math.Float64bits(value))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindText binds a string to the prepared statement.
|
||||
@@ -192,14 +172,11 @@ func (s *Stmt) BindFloat(param int, value float64) error {
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindText(param int, value string) error {
|
||||
ptr := s.c.newString(value)
|
||||
r, err := s.c.api.bindText.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindText,
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(value)),
|
||||
s.c.api.destructor, _UTF8)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
uint64(s.c.api.destructor), _UTF8)
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindBlob binds a []byte to the prepared statement.
|
||||
@@ -209,14 +186,21 @@ func (s *Stmt) BindText(param int, value string) error {
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindBlob(param int, value []byte) error {
|
||||
ptr := s.c.newBytes(value)
|
||||
r, err := s.c.api.bindBlob.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindBlob,
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(value)),
|
||||
s.c.api.destructor)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
uint64(s.c.api.destructor))
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindZeroBlob(param int, n int64) error {
|
||||
r := s.c.call(s.c.api.bindZeroBlob,
|
||||
uint64(s.handle), uint64(param), uint64(n))
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindNull binds a NULL to the prepared statement.
|
||||
@@ -224,24 +208,53 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindNull(param int) error {
|
||||
r, err := s.c.api.bindNull.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.bindNull,
|
||||
uint64(s.handle), uint64(param))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindTime binds a [time.Time] to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
|
||||
if format == TimeFormatDefault {
|
||||
return s.bindRFC3339Nano(param, value)
|
||||
}
|
||||
return s.c.error(r[0])
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
s.BindText(param, v)
|
||||
case int64:
|
||||
s.BindInt64(param, v)
|
||||
case float64:
|
||||
s.BindFloat(param, v)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
|
||||
ptr := s.c.new(maxlen)
|
||||
buf := util.View(s.c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
r := s.c.call(s.c.api.bindText,
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(buf)),
|
||||
uint64(s.c.api.destructor), _UTF8)
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_count.html
|
||||
func (s *Stmt) ColumnCount() int {
|
||||
r, err := s.c.api.columnCount.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.columnCount,
|
||||
uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// ColumnName returns the name of the result column.
|
||||
@@ -249,17 +262,14 @@ func (s *Stmt) ColumnCount() int {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_name.html
|
||||
func (s *Stmt) ColumnName(col int) string {
|
||||
r, err := s.c.api.columnName.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.columnName,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
ptr := uint32(r)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
return s.c.mem.readString(ptr, _MAX_STRING)
|
||||
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
|
||||
}
|
||||
|
||||
// ColumnType returns the initial [Datatype] of the result column.
|
||||
@@ -267,12 +277,9 @@ func (s *Stmt) ColumnName(col int) string {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnType(col int) Datatype {
|
||||
r, err := s.c.api.columnType.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.columnType,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return Datatype(r[0])
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
// ColumnBool returns the value of the result column as a bool.
|
||||
@@ -302,12 +309,9 @@ func (s *Stmt) ColumnInt(col int) int {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnInt64(col int) int64 {
|
||||
r, err := s.c.api.columnInteger.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.columnInteger,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(r[0])
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// ColumnFloat returns the value of the result column as a float64.
|
||||
@@ -315,12 +319,34 @@ func (s *Stmt) ColumnInt64(col int) int64 {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnFloat(col int) float64 {
|
||||
r, err := s.c.api.columnFloat.Call(s.c.ctx,
|
||||
r := s.c.call(s.c.api.columnFloat,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// ColumnTime returns the value of the result column as a [time.Time].
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
|
||||
var v any
|
||||
switch s.ColumnType(col) {
|
||||
case INTEGER:
|
||||
v = s.ColumnInt64(col)
|
||||
case FLOAT:
|
||||
v = s.ColumnFloat(col)
|
||||
case TEXT, BLOB:
|
||||
v = s.ColumnText(col)
|
||||
case NULL:
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return math.Float64frombits(r[0])
|
||||
t, err := format.Decode(v)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// ColumnText returns the value of the result column as a string.
|
||||
@@ -328,30 +354,7 @@ func (s *Stmt) ColumnFloat(col int) float64 {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnText(col int) string {
|
||||
r, err := s.c.api.columnText.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.err = s.c.error(r[0])
|
||||
return ""
|
||||
}
|
||||
|
||||
r, err = s.c.api.columnBytes.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mem := s.c.mem.view(ptr, uint32(r[0]))
|
||||
return string(mem)
|
||||
return string(s.ColumnRawText(col))
|
||||
}
|
||||
|
||||
// ColumnBlob appends to buf and returns
|
||||
@@ -360,28 +363,56 @@ func (s *Stmt) ColumnText(col int) string {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
|
||||
r, err := s.c.api.columnBlob.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
r, err = s.c.api.errcode.Call(s.c.ctx, uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.err = s.c.error(r[0])
|
||||
return buf[0:0]
|
||||
}
|
||||
|
||||
r, err = s.c.api.columnBytes.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mem := s.c.mem.view(ptr, uint32(r[0]))
|
||||
return append(buf[0:0], mem...)
|
||||
return append(buf, s.ColumnRawBlob(col)...)
|
||||
}
|
||||
|
||||
// ColumnRawText returns the value of the result column as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Stmt] methods.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnText,
|
||||
uint64(s.handle), uint64(col))
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
// ColumnRawBlob returns the value of the result column as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Stmt] methods.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnRawBlob(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnBlob,
|
||||
uint64(s.handle), uint64(col))
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r := s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
}
|
||||
|
||||
// Return true if stmt is an empty SQL statement.
|
||||
// This is used as an optimization.
|
||||
// It's OK to always return false here.
|
||||
func emptyStatement(stmt string) bool {
|
||||
for _, b := range []byte(stmt) {
|
||||
switch b {
|
||||
case ' ', '\n', '\r', '\t', '\v', '\f':
|
||||
case ';':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
430
stmt_test.go
430
stmt_test.go
@@ -1,400 +1,60 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStmt(t *testing.T) {
|
||||
func Test_emptyStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stmt string
|
||||
want bool
|
||||
}{
|
||||
{"empty", "", true},
|
||||
{"space", " ", true},
|
||||
{"separator", ";\n ", true},
|
||||
{"begin", "BEGIN", false},
|
||||
{"select", "SELECT 1;", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := emptyStatement(tt.stmt); got != tt.want {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Fuzz_emptyStatement(f *testing.F) {
|
||||
f.Add("")
|
||||
f.Add(" ")
|
||||
f.Add(";\n ")
|
||||
f.Add("; ;\v")
|
||||
f.Add("BEGIN")
|
||||
f.Add("SELECT 1;")
|
||||
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
f.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`INSERT INTO test(col) VALUES(?)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if got := stmt.BindCount(); got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
err = stmt.BindBool(1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.ClearBindings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindBool(1, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindInt(1, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindFloat(1, math.Pi)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindNull(1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindText(1, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindText(1, "text")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindBlob(1, []byte("blob"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.BindBlob(1, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The table should have: 0, NULL, 1, 2, π, NULL, "", "text", `blob`, NULL
|
||||
stmt, _, err = db.Prepare(`SELECT col FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
f.Fuzz(func(t *testing.T, sql string) {
|
||||
// If empty, SQLite parses it as empty.
|
||||
if emptyStatement(sql) {
|
||||
stmt, tail, err := db.Prepare(sql)
|
||||
if err != nil {
|
||||
t.Errorf("%q, %v", sql, err)
|
||||
}
|
||||
if stmt != nil {
|
||||
t.Errorf("%q, %v", sql, stmt)
|
||||
}
|
||||
if tail != "" {
|
||||
t.Errorf("%q", sql)
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "0" {
|
||||
t.Errorf("got %q, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
|
||||
t.Errorf("got %q, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 1 {
|
||||
t.Errorf("got %v, want one", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 1 {
|
||||
t.Errorf("got %v, want one", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "1" {
|
||||
t.Errorf("got %q, want one", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
|
||||
t.Errorf("got %q, want one", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 2 {
|
||||
t.Errorf("got %v, want two", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 2 {
|
||||
t.Errorf("got %v, want two", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "2" {
|
||||
t.Errorf("got %q, want two", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
|
||||
t.Errorf("got %q, want two", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 3 {
|
||||
t.Errorf("got %v, want three", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != math.Pi {
|
||||
t.Errorf("got %v, want π", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "3.14159265358979" {
|
||||
t.Errorf("got %q, want π", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
|
||||
t.Errorf("got %q, want π", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "text" {
|
||||
t.Errorf(`got %q, want "text"`, got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
|
||||
t.Errorf(`got %q, want "text"`, got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "blob" {
|
||||
t.Errorf(`got %q, want "blob"`, got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
|
||||
t.Errorf(`got %q, want "blob"`, got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmt_Close(t *testing.T) {
|
||||
var stmt *Stmt
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
func TestStmt_BindName(t *testing.T) {
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
want := []string{"", "", "", "", "?5", ":AAA", "@AAA", "$AAA"}
|
||||
stmt, _, err := db.Prepare(`SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if got := stmt.BindCount(); got != len(want) {
|
||||
t.Errorf("got %d, want %d", got, len(want))
|
||||
}
|
||||
|
||||
for i, name := range want {
|
||||
id := i + 1
|
||||
if got := stmt.BindName(id); got != name {
|
||||
t.Errorf("got %q, want %q", got, name)
|
||||
}
|
||||
if name == "" {
|
||||
id = 0
|
||||
}
|
||||
if got := stmt.BindIndex(name); got != id {
|
||||
t.Errorf("got %d, want %d", got, id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
169
tests/backup_test.go
Normal file
169
tests/backup_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestBackup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backupName := filepath.Join(t.TempDir(), "backup.db")
|
||||
|
||||
func() { // Create backup.
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Backup("main", backupName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
func() { // Restore backup.
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Restore("main", backupName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 1, 2}
|
||||
names := []string{"go", "zig", "whatever"}
|
||||
for ; stmt.Step(); row++ {
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
func() { // Errors.
|
||||
db, err := sqlite3.Open(backupName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.BeginExclusive()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Restore("main", backupName)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Restore("main", backupName)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
func() { // Incremental.
|
||||
db, err := sqlite3.Open(backupName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
b, err := db.BackupInit("main", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Close()
|
||||
|
||||
done, err := b.Step(1)
|
||||
if done {
|
||||
t.Error("want false")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n := b.Remaining()
|
||||
if n != 1 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
n = b.PageCount()
|
||||
if n != 2 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
err = b.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
418
tests/blob_test.go
Normal file
418
tests/blob_test.go
Normal file
@@ -0,0 +1,418 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/adler32"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestBlob(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
size := blob.Size()
|
||||
if size != 1024 {
|
||||
t.Errorf("got %d, want 1024", size)
|
||||
}
|
||||
|
||||
var data [1280]byte
|
||||
_, err = rand.Read(data[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = blob.Write(data[:size/2])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n, err := blob.Write(data[:])
|
||||
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
|
||||
}
|
||||
|
||||
_, err = blob.Write(data[size/2 : size])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = blob.Seek(size/4, io.SeekStart)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, err := io.ReadAll(blob); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !bytes.Equal(got, data[size/4:size]) {
|
||||
t.Errorf("got %q, want %q", got, data[size/4:size])
|
||||
}
|
||||
|
||||
if n, err := blob.Read(make([]byte, 1)); n != 0 || err != io.EOF {
|
||||
t.Errorf("got (%d, %v), want (0, EOF)", n, err)
|
||||
}
|
||||
|
||||
if err := blob.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_large(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1000000))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
size := blob.Size()
|
||||
if size != 1000000 {
|
||||
t.Errorf("got %d, want 1000000", size)
|
||||
}
|
||||
|
||||
hash := adler32.New()
|
||||
_, err = io.CopyN(blob, io.TeeReader(rand.Reader, hash), 1000000)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = blob.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := hash.Sum32()
|
||||
hash.Reset()
|
||||
_, err = io.Copy(hash, blob)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := hash.Sum32(); got != want {
|
||||
t.Fatalf("got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if err := blob.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_overflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
n, err := blob.ReadFrom(rand.Reader)
|
||||
if n != 1024 || !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
|
||||
}
|
||||
|
||||
n, err = blob.ReadFrom(rand.Reader)
|
||||
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
|
||||
}
|
||||
|
||||
_, err = blob.Seek(-128, io.SeekEnd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n, err = blob.WriteTo(io.Discard)
|
||||
if n != 128 || err != nil {
|
||||
t.Fatalf("got (%d, %v), want (128, nil)", n, err)
|
||||
}
|
||||
|
||||
n, err = blob.WriteTo(io.Discard)
|
||||
if n != 0 || err != nil {
|
||||
t.Fatalf("got (%d, %v), want (0, nil)", n, err)
|
||||
}
|
||||
|
||||
if err := blob.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.OpenBlob("", "test", "col", db.LastInsertRowID(), false)
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_Write_readonly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
_, err = blob.Write([]byte("data"))
|
||||
if !errors.Is(err, sqlite3.READONLY) {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_Read_expired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
err = db.Exec(`DELETE FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = io.ReadAll(blob)
|
||||
if !errors.Is(err, sqlite3.ABORT) {
|
||||
t.Fatal("want error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_Seek(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
_, err = blob.Seek(0, 10)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
|
||||
_, err = blob.Seek(-1, io.SeekCurrent)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
|
||||
n, err := blob.Seek(1, io.SeekEnd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != blob.Size()+1 {
|
||||
t.Errorf("got %d, want %d", n, blob.Size())
|
||||
}
|
||||
|
||||
_, err = blob.Write([]byte("data"))
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlob_Reopen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var rowids []int64
|
||||
for i := 0; i < 100; i++ {
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rowids = append(rowids, db.LastInsertRowID())
|
||||
}
|
||||
|
||||
var blob *sqlite3.Blob
|
||||
|
||||
for i, rowid := range rowids {
|
||||
if i > 0 {
|
||||
err = blob.Reopen(rowid)
|
||||
} else {
|
||||
blob, err = db.OpenBlob("main", "test", "col", rowid, true)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(blob, "blob %d\n", i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := blob.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i, rowid := range rowids {
|
||||
if i > 0 {
|
||||
err = blob.Reopen(rowid)
|
||||
} else {
|
||||
blob, err = db.OpenBlob("main", "test", "col", rowid, false)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got int
|
||||
_, err = fmt.Fscanf(blob, "blob %d\n", &got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != i {
|
||||
t.Errorf("got %d, want %d", got, i)
|
||||
}
|
||||
}
|
||||
if err := blob.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
182
tests/bradfitz/sql_test.go
Normal file
182
tests/bradfitz/sql_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package bradfitz
|
||||
|
||||
// Adapted from: https://github.com/bradfitz/go-sql-test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
type Tester interface {
|
||||
RunTest(*testing.T, func(params))
|
||||
}
|
||||
|
||||
var (
|
||||
sqlite Tester = sqliteDB{}
|
||||
)
|
||||
|
||||
const TablePrefix = "gosqltest_"
|
||||
|
||||
type sqliteDB struct{}
|
||||
|
||||
type params struct {
|
||||
dbType Tester
|
||||
*testing.T
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
func (t params) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := t.DB.Exec(sql, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.Join(t.TempDir(), "foo.db")+
|
||||
"?_pragma=busy_timeout(10000)&_pragma=synchronous(off)")
|
||||
if err != nil {
|
||||
t.Fatalf("foo.db open fail: %v", err)
|
||||
}
|
||||
fn(params{sqlite, t, db})
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("foo.db close fail: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
|
||||
|
||||
func testBlobs(t params) {
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, bar blob)")
|
||||
t.mustExec("insert into "+TablePrefix+"foo (id, bar) values(?,?)", 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = t.QueryRow("select bar from "+TablePrefix+"foo where id = ?", 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyQueryRow_SQLite(t *testing.T) { sqlite.RunTest(t, testManyQueryRow) }
|
||||
|
||||
func testManyQueryRow(t params) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
t.mustExec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
|
||||
t.mustExec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := t.QueryRow("select name from "+TablePrefix+"foo where id = ?", 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxQuery_SQLite(t *testing.T) { sqlite.RunTest(t, testTxQuery) }
|
||||
|
||||
func testTxQuery(t params) {
|
||||
tx, err := t.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create table " + TablePrefix + "foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Logf("cannot drop table "+TablePrefix+"foo: %s", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("insert into "+TablePrefix+"foo (id, name) values(?,?)", 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query("select name from "+TablePrefix+"foo where id = ?", 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one row")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmt_SQLite(t *testing.T) { sqlite.RunTest(t, testPreparedStmt) }
|
||||
|
||||
func testPreparedStmt(t params) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
t.mustExec("CREATE TABLE " + TablePrefix + "t (count INT)")
|
||||
sel, err := t.Prepare("SELECT count FROM " + TablePrefix + "t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := t.Prepare("INSERT INTO " + TablePrefix + "t (count) VALUES (?)")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package compile_empty
|
||||
package compile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package compile_empty
|
||||
package compile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestCompile_empty(t *testing.T) {
|
||||
func TestCompile_missing(t *testing.T) {
|
||||
sqlite3.Path = "sqlite3.wasm"
|
||||
_, err := sqlite3.Open(":memory:")
|
||||
if err == nil {
|
||||
|
||||
14
tests/compile/nil/compile_test.go
Normal file
14
tests/compile/nil/compile_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package compile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestCompile_nil(t *testing.T) {
|
||||
_, err := sqlite3.Open(":memory:")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
}
|
||||
291
tests/conn_test.go
Normal file
291
tests/conn_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestConn_Open_dir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := sqlite3.OpenFlags(".", 0)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Open_notfound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := sqlite3.OpenFlags("test.db", sqlite3.OPEN_READONLY)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Open_modeof(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "test.db")
|
||||
mode := filepath.Join(dir, "modeof.txt")
|
||||
|
||||
fd, err := os.OpenFile(mode, os.O_CREATE, 0624)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fi, err := fd.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fd.Close()
|
||||
|
||||
db, err := sqlite3.Open("file:" + file + "?modeof=" + mode)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
di, err := os.Stat(file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
if di.Mode() != fi.Mode() {
|
||||
t.Errorf("got %v, want %v", di.Mode(), fi.Mode())
|
||||
}
|
||||
|
||||
_, err = sqlite3.Open("file:" + file + "?modeof=" + mode + "2")
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
var conn *sqlite3.Conn
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestConn_Close_BUSY(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`BEGIN`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.BUSY) {
|
||||
t.Errorf("got %v, want sqlite3.BUSY", err)
|
||||
}
|
||||
var terr interface{ Temporary() bool }
|
||||
if !errors.As(err, &terr) || !terr.Temporary() {
|
||||
t.Error("not temporary", err)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: database is locked: unable to close due to unfinalized statements or unfinished backups` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Pragma(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open("file::memory:?_pragma=busy_timeout(1000)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
got, err := db.Pragma("busy_timeout")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []string{"1000"}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
var serr *sqlite3.Error
|
||||
_, err = db.Pragma("+")
|
||||
if err == nil {
|
||||
t.Error("want: error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: near "+": syntax error` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_SetInterrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db.SetInterrupt(ctx)
|
||||
|
||||
// Interrupt doesn't interrupt this.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(context.Background())
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
WITH RECURSIVE
|
||||
fibonacci (curr, next)
|
||||
AS (
|
||||
SELECT 0, 1
|
||||
UNION ALL
|
||||
SELECT next, curr + next FROM fibonacci
|
||||
LIMIT 1e6
|
||||
)
|
||||
SELECT min(curr) FROM fibonacci
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
db.SetInterrupt(ctx)
|
||||
cancel()
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
// Interrupting sticks.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
db.SetInterrupt(ctx)
|
||||
|
||||
// Interrupting can be cleared.
|
||||
err = db.Exec(`SELECT 1`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(``)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt != nil {
|
||||
t.Error("want nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_tail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, tail, err := db.Prepare(`SELECT 1; -- HERE`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if !strings.Contains(tail, "-- HERE") {
|
||||
t.Errorf("got %q", tail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Prepare_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var serr *sqlite3.Error
|
||||
|
||||
_, _, err = db.Prepare(`SELECT`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: incomplete input` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
|
||||
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.ERROR", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := serr.SQL(); got != `FRM sqlite_schema` {
|
||||
t.Error("got SQL:", got)
|
||||
}
|
||||
if got := serr.Error(); got != `sqlite3: SQL logic error: near "FRM": syntax error` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
}
|
||||
@@ -1,21 +1,44 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/wal.db
|
||||
var waldb []byte
|
||||
|
||||
func TestDB_memory(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, ":memory:")
|
||||
}
|
||||
|
||||
func TestDB_file(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func TestDB_wal(t *testing.T) {
|
||||
t.Parallel()
|
||||
wal := filepath.Join(t.TempDir(), "test.db")
|
||||
err := os.WriteFile(wal, waldb, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDB(t, wal)
|
||||
}
|
||||
|
||||
func TestDB_vfs(t *testing.T) {
|
||||
testDB(t, "file:test.db?vfs=memdb")
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
db, err := sqlite3.Open(name)
|
||||
if err != nil {
|
||||
@@ -28,10 +51,14 @@ func testDB(t *testing.T, name string) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes := db.Changes()
|
||||
if changes != 3 {
|
||||
t.Errorf("got %d want 3", changes)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
@@ -43,18 +70,22 @@ func testDB(t *testing.T, name string) {
|
||||
ids := []int{0, 1, 2}
|
||||
names := []string{"go", "zig", "whatever"}
|
||||
for ; stmt.Step(); row++ {
|
||||
if ids[row] != stmt.ColumnInt(0) {
|
||||
t.Errorf("got %d, want %d", stmt.ColumnInt(0), ids[row])
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if names[row] != stmt.ColumnText(1) {
|
||||
t.Errorf("got %q, want %q", stmt.ColumnText(1), names[row])
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d rows, want %d", row, len(ids))
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDir(t *testing.T) {
|
||||
_, err := sqlite3.Open(".")
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var serr *sqlite3.Error
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.CANTOPEN {
|
||||
t.Errorf("got %d, want sqlite3.CANTOPEN", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: unable to open database file` {
|
||||
t.Error("got message: ", got)
|
||||
}
|
||||
}
|
||||
116
tests/driver_test.go
Normal file
116
tests/driver_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDriver(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.ExecContext(ctx,
|
||||
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changes != 0 {
|
||||
t.Errorf("got %d want 0", changes)
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != 0 {
|
||||
t.Errorf("got %d want 0", changes)
|
||||
}
|
||||
|
||||
res, err = conn.ExecContext(ctx,
|
||||
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes, err = res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changes != 3 {
|
||||
t.Errorf("got %d want 3", changes)
|
||||
}
|
||||
|
||||
stmt, err := conn.PrepareContext(context.Background(),
|
||||
`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 1, 2}
|
||||
names := []string{"go", "zig", "whatever"}
|
||||
for ; rows.Next(); row++ {
|
||||
var id int
|
||||
var name string
|
||||
err := rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
77
tests/ext_test.go
Normal file
77
tests/ext_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func Test_base64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// base64
|
||||
stmt, _, err := db.Prepare(`SELECT base64('TWFueSBoYW5kcyBtYWtlIGxpZ2h0IHdvcmsu')`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if !stmt.Step() {
|
||||
t.Fatal("expected one row")
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "Many hands make light work." {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_decimal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT decimal_add(decimal('0.1'), decimal('0.2')) = decimal('0.3')`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if !stmt.Step() {
|
||||
t.Fatal("expected one row")
|
||||
}
|
||||
if !stmt.ColumnBool(0) {
|
||||
t.Error("want true")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_uint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT 'z2' < 'z11' COLLATE UINT`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if !stmt.Step() {
|
||||
t.Fatal("expected one row")
|
||||
}
|
||||
if !stmt.ColumnBool(0) {
|
||||
t.Error("want true")
|
||||
}
|
||||
}
|
||||
188
tests/func_test.go
Normal file
188
tests/func_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestCreateFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg := arg[0]; arg.Int() {
|
||||
case 0:
|
||||
ctx.ResultInt(arg.Int())
|
||||
case 1:
|
||||
ctx.ResultInt64(arg.Int64())
|
||||
case 2:
|
||||
ctx.ResultBool(arg.Bool())
|
||||
case 3:
|
||||
ctx.ResultFloat(arg.Float())
|
||||
case 4:
|
||||
ctx.ResultText(arg.Text())
|
||||
case 5:
|
||||
ctx.ResultBlob(arg.Blob(nil))
|
||||
case 6:
|
||||
ctx.ResultZeroBlob(arg.Int64())
|
||||
case 7:
|
||||
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
|
||||
case 8:
|
||||
ctx.ResultNull()
|
||||
case 9:
|
||||
ctx.ResultError(sqlite3.FULL)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 1 {
|
||||
t.Errorf("got %v, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 3 {
|
||||
t.Errorf("got %v, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "4" {
|
||||
t.Errorf("got %s, want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
|
||||
t.Errorf("got %s, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
|
||||
t.Errorf("got %v, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
t.Error("want error")
|
||||
}
|
||||
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
|
||||
t.Errorf("got %v, want sqlite3.FULL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnyCollationNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.AnyCollationNeeded()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 2, 1}
|
||||
names := []string{"go", "whatever", "zig"}
|
||||
for ; stmt.Step(); row++ {
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -11,21 +11,54 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
testParallel(t, name, 100)
|
||||
var iter int
|
||||
if testing.Short() {
|
||||
iter = 1000
|
||||
} else {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
name := "file:" +
|
||||
filepath.Join(t.TempDir(), "test.db") +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestMemory(t *testing.T) {
|
||||
var iter int
|
||||
if testing.Short() {
|
||||
iter = 1000
|
||||
} else {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
name := "file:/test.db?vfs=memdb" +
|
||||
"&_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(memory)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestMultiProcess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbname", name)
|
||||
file := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbfile", file)
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
|
||||
out, err := cmd.StdoutPipe()
|
||||
@@ -44,17 +77,22 @@ func TestMultiProcess(t *testing.T) {
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
if err := cmd.Wait(); err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
}
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestChildProcess(t *testing.T) {
|
||||
name := os.Getenv("TestMultiProcess_dbname")
|
||||
if name == "" || testing.Short() {
|
||||
return
|
||||
file := os.Getenv("TestMultiProcess_dbfile")
|
||||
if file == "" || testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
}
|
||||
|
||||
@@ -66,20 +104,12 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA locking_mode = NORMAL;
|
||||
PRAGMA busy_timeout = 10000;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,10 +124,7 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA locking_mode = NORMAL;
|
||||
PRAGMA busy_timeout = 10000;
|
||||
`)
|
||||
err = db.Exec(`PRAGMA busy_timeout=10000`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -133,7 +160,7 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
|
||||
var group errgroup.Group
|
||||
group.SetLimit(4)
|
||||
group.SetLimit(6)
|
||||
for i := 0; i < n; i++ {
|
||||
if i&7 != 7 {
|
||||
group.Go(reader)
|
||||
@@ -143,7 +170,7 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
err = group.Wait()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
463
tests/stmt_test.go
Normal file
463
tests/stmt_test.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestStmt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`INSERT INTO test VALUES (?)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if got := stmt.BindCount(); got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
if err := stmt.BindBool(1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBool(1, true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindInt(1, 2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindFloat(1, math.Pi); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindNull(1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindText(1, ""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindText(1, "text"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBlob(1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindZeroBlob(1, 4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.ClearBindings(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
|
||||
stmt, _, err = db.Prepare(`SELECT col FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "0" {
|
||||
t.Errorf("got %q, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
|
||||
t.Errorf("got %q, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 1 {
|
||||
t.Errorf("got %v, want one", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 1 {
|
||||
t.Errorf("got %v, want one", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "1" {
|
||||
t.Errorf("got %q, want one", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
|
||||
t.Errorf("got %q, want one", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 2 {
|
||||
t.Errorf("got %v, want two", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 2 {
|
||||
t.Errorf("got %v, want two", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "2" {
|
||||
t.Errorf("got %q, want two", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
|
||||
t.Errorf("got %q, want two", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 3 {
|
||||
t.Errorf("got %v, want three", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != math.Pi {
|
||||
t.Errorf("got %v, want π", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "3.14159265358979" {
|
||||
t.Errorf("got %q, want π", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
|
||||
t.Errorf("got %q, want π", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "text" {
|
||||
t.Errorf(`got %q, want "text"`, got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
|
||||
t.Errorf(`got %q, want "text"`, got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "blob" {
|
||||
t.Errorf(`got %q, want "blob"`, got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
|
||||
t.Errorf(`got %q, want "blob"`, got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmt_Close(t *testing.T) {
|
||||
var stmt *sqlite3.Stmt
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
func TestStmt_BindName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
want := []string{"", "", "", "", "?5", ":AAA", "@AAA", "$AAA"}
|
||||
stmt, _, err := db.Prepare(`SELECT ?, ?5, :AAA, @AAA, $AAA`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if got := stmt.BindCount(); got != len(want) {
|
||||
t.Errorf("got %d, want %d", got, len(want))
|
||||
}
|
||||
|
||||
for i, name := range want {
|
||||
id := i + 1
|
||||
if got := stmt.BindName(id); got != name {
|
||||
t.Errorf("got %q, want %q", got, name)
|
||||
}
|
||||
if name == "" {
|
||||
id = 0
|
||||
}
|
||||
if got := stmt.BindIndex(name); got != id {
|
||||
t.Errorf("got %d, want %d", got, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmt_ColumnTime(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT ?, ?, ?, datetime(), unixepoch(), julianday(), NULL, 'abc'`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
err = stmt.BindTime(1, reference, sqlite3.TimeFormat4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = stmt.BindTime(2, reference, sqlite3.TimeFormatUnixMilli)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = stmt.BindTime(3, reference, sqlite3.TimeFormatJulianDay)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if now := time.Now(); stmt.Step() {
|
||||
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); !got.Equal(reference) {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
if got := stmt.ColumnTime(1, sqlite3.TimeFormatAuto); !got.Equal(reference) {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
if got := stmt.ColumnTime(2, sqlite3.TimeFormatAuto); got.Sub(reference).Abs() > time.Millisecond {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
|
||||
if got := stmt.ColumnTime(3, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
|
||||
t.Errorf("got %v, want %v", got, now)
|
||||
}
|
||||
if got := stmt.ColumnTime(4, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second {
|
||||
t.Errorf("got %v, want %v", got, now)
|
||||
}
|
||||
if got := stmt.ColumnTime(5, sqlite3.TimeFormatAuto); got.Sub(now).Abs() > time.Second/10 {
|
||||
t.Errorf("got %v, want %v", got, now)
|
||||
}
|
||||
|
||||
if got := stmt.ColumnTime(6, sqlite3.TimeFormatAuto); got != (time.Time{}) {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnTime(7, sqlite3.TimeFormatAuto); got != (time.Time{}) {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if stmt.Err() == nil {
|
||||
t.Errorf("want error")
|
||||
}
|
||||
}
|
||||
}
|
||||
BIN
tests/testdata/wal.db
vendored
Normal file
BIN
tests/testdata/wal.db
vendored
Normal file
Binary file not shown.
169
tests/time_test.go
Normal file
169
tests/time_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestTimeFormat_Encode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
tests := []struct {
|
||||
fmt sqlite3.TimeFormat
|
||||
time time.Time
|
||||
want any
|
||||
}{
|
||||
{sqlite3.TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
|
||||
{sqlite3.TimeFormatJulianDay, reference, 2456572.849526851851852},
|
||||
{sqlite3.TimeFormatUnix, reference, int64(1381134199)},
|
||||
{sqlite3.TimeFormatUnixFrac, reference, 1381134199.120},
|
||||
{sqlite3.TimeFormatUnixMilli, reference, int64(1381134199_120)},
|
||||
{sqlite3.TimeFormatUnixMicro, reference, int64(1381134199_120000)},
|
||||
{sqlite3.TimeFormatUnixNano, reference, int64(1381134199_120000000)},
|
||||
{sqlite3.TimeFormat7, reference, "2013-10-07T08:23:19.120"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeFormat_Decode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
tests := []struct {
|
||||
fmt sqlite3.TimeFormat
|
||||
val any
|
||||
want time.Time
|
||||
wantDelta time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
|
||||
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
|
||||
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
|
||||
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
|
||||
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
|
||||
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
|
||||
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
|
||||
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
|
||||
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
|
||||
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
|
||||
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
got, err := tt.fmt.Decode(tt.val)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want.Sub(got).Abs() > tt.wantDelta {
|
||||
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_timeCollation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
|
||||
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
|
||||
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
|
||||
stmt.Step()
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var t0 time.Time
|
||||
for stmt.Step() {
|
||||
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
|
||||
if t0.After(t1) {
|
||||
t.Errorf("got %v after %v", t0, t1)
|
||||
}
|
||||
t0 = t1
|
||||
}
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
535
tests/tx_test.go
Normal file
535
tests/tx_test.go
Normal file
@@ -0,0 +1,535 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestConn_Transaction_exec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
errFailed := errors.New("failed")
|
||||
|
||||
count := func() int {
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnInt(0)
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return 0
|
||||
}
|
||||
|
||||
insert := func(succeed bool) (err error) {
|
||||
tx := db.Begin()
|
||||
defer tx.End(&err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if succeed {
|
||||
return nil
|
||||
}
|
||||
return errFailed
|
||||
}
|
||||
|
||||
err = insert(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := count(); got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
err = insert(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := count(); got != 2 {
|
||||
t.Errorf("got %d, want 2", got)
|
||||
}
|
||||
|
||||
err = insert(false)
|
||||
if err != errFailed {
|
||||
t.Errorf("got %v, want errFailed", err)
|
||||
}
|
||||
if got := count(); got != 2 {
|
||||
t.Errorf("got %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_panic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('one');`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
panics := func() (err error) {
|
||||
tx := db.Begin()
|
||||
defer tx.End(&err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
panic("omg!")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
p := recover()
|
||||
if p != "omg!" {
|
||||
t.Errorf("got %v, want panic", p)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
}()
|
||||
|
||||
err = panics()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_interrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx, err := db.BeginImmediate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`INSERT INTO test VALUES (1)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx.End(&err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db.SetInterrupt(ctx)
|
||||
|
||||
tx, err = db.BeginExclusive()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`INSERT INTO test VALUES (2)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
_, err = db.BeginImmediate()
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (3)`)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
err = nil
|
||||
tx.End(&err)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(context.Background())
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
}
|
||||
err = stmt.Err()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_interrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db.SetInterrupt(ctx)
|
||||
cancel()
|
||||
|
||||
tx := db.Begin()
|
||||
|
||||
err = tx.Commit()
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
err = nil
|
||||
tx.End(&err)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Transaction_rollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
err = db.Exec(`INSERT INTO test VALUES (1)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx.End(&err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
}
|
||||
err = stmt.Err()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Savepoint_exec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
errFailed := errors.New("failed")
|
||||
|
||||
count := func() int {
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnInt(0)
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return 0
|
||||
}
|
||||
|
||||
insert := func(succeed bool) (err error) {
|
||||
defer db.Savepoint().Release(&err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if succeed {
|
||||
return nil
|
||||
}
|
||||
return errFailed
|
||||
}
|
||||
|
||||
err = insert(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := count(); got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
err = insert(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := count(); got != 2 {
|
||||
t.Errorf("got %d, want 2", got)
|
||||
}
|
||||
|
||||
err = insert(false)
|
||||
if err != errFailed {
|
||||
t.Errorf("got %v, want errFailed", err)
|
||||
}
|
||||
if got := count(); got != 2 {
|
||||
t.Errorf("got %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Savepoint_panic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('one');`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
panics := func() (err error) {
|
||||
defer db.Savepoint().Release(&err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
panic("omg!")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
p := recover()
|
||||
if p != "omg!" {
|
||||
t.Errorf("got %v, want panic", p)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
}()
|
||||
|
||||
err = panics()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Savepoint_interrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savept := db.Savepoint()
|
||||
err = db.Exec(`INSERT INTO test VALUES (1)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
savept.Release(&err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db.SetInterrupt(ctx)
|
||||
|
||||
savept1 := db.Savepoint()
|
||||
err = db.Exec(`INSERT INTO test VALUES (2)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
savept2 := db.Savepoint()
|
||||
err = db.Exec(`INSERT INTO test VALUES (3)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
db.Savepoint().Release(&err)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (4)`)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
err = context.Canceled
|
||||
savept2.Release(&err)
|
||||
if err != context.Canceled {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = nil
|
||||
savept1.Release(&err)
|
||||
if !errors.Is(err, sqlite3.INTERRUPT) {
|
||||
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
|
||||
}
|
||||
|
||||
db.SetInterrupt(context.Background())
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
}
|
||||
err = stmt.Err()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Savepoint_rollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savept := db.Savepoint()
|
||||
err = db.Exec(`INSERT INTO test VALUES (1)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
savept.Release(&err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
got := stmt.ColumnInt(0)
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
}
|
||||
err = stmt.Err()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,23 @@
|
||||
package sqlite3
|
||||
package tests
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestDatatype_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
data Datatype
|
||||
data sqlite3.Datatype
|
||||
want string
|
||||
}{
|
||||
{INTEGER, "INTEGER"},
|
||||
{FLOAT, "FLOAT"},
|
||||
{TEXT, "TEXT"},
|
||||
{BLOB, "BLOB"},
|
||||
{NULL, "NULL"},
|
||||
{sqlite3.INTEGER, "INTEGER"},
|
||||
{sqlite3.FLOAT, "FLOAT"},
|
||||
{sqlite3.TEXT, "TEXT"},
|
||||
{sqlite3.BLOB, "BLOB"},
|
||||
{sqlite3.NULL, "NULL"},
|
||||
{10, "10"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
36
tests/vfs_test.go
Normal file
36
tests/vfs_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
"github.com/ncruces/go-sqlite3/vfs/readervfs"
|
||||
)
|
||||
|
||||
func TestMemoryVFS_Open_notfound(t *testing.T) {
|
||||
memdb.Delete("demo.db")
|
||||
|
||||
_, err := sqlite3.Open("file:/demo.db?vfs=memdb&mode=ro")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReaderVFS_Open_notfound(t *testing.T) {
|
||||
readervfs.Delete("demo.db")
|
||||
|
||||
_, err := sqlite3.Open("file:demo.db?vfs=reader")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
340
time.go
Normal file
340
time.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/julianday"
|
||||
)
|
||||
|
||||
// TimeFormat specifies how to encode/decode time values.
|
||||
//
|
||||
// See the documentation for the [TimeFormatDefault] constant
|
||||
// for formats recognized by SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
type TimeFormat string
|
||||
|
||||
// TimeFormats recognized by SQLite to encode/decode time values.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
const (
|
||||
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano
|
||||
|
||||
// Text formats
|
||||
TimeFormat1 TimeFormat = "2006-01-02"
|
||||
TimeFormat2 TimeFormat = "2006-01-02 15:04"
|
||||
TimeFormat3 TimeFormat = "2006-01-02 15:04:05"
|
||||
TimeFormat4 TimeFormat = "2006-01-02 15:04:05.000"
|
||||
TimeFormat5 TimeFormat = "2006-01-02T15:04"
|
||||
TimeFormat6 TimeFormat = "2006-01-02T15:04:05"
|
||||
TimeFormat7 TimeFormat = "2006-01-02T15:04:05.000"
|
||||
TimeFormat8 TimeFormat = "15:04"
|
||||
TimeFormat9 TimeFormat = "15:04:05"
|
||||
TimeFormat10 TimeFormat = "15:04:05.000"
|
||||
|
||||
TimeFormat2TZ = TimeFormat2 + "Z07:00"
|
||||
TimeFormat3TZ = TimeFormat3 + "Z07:00"
|
||||
TimeFormat4TZ = TimeFormat4 + "Z07:00"
|
||||
TimeFormat5TZ = TimeFormat5 + "Z07:00"
|
||||
TimeFormat6TZ = TimeFormat6 + "Z07:00"
|
||||
TimeFormat7TZ = TimeFormat7 + "Z07:00"
|
||||
TimeFormat8TZ = TimeFormat8 + "Z07:00"
|
||||
TimeFormat9TZ = TimeFormat9 + "Z07:00"
|
||||
TimeFormat10TZ = TimeFormat10 + "Z07:00"
|
||||
|
||||
// Numeric formats
|
||||
TimeFormatJulianDay TimeFormat = "julianday"
|
||||
TimeFormatUnix TimeFormat = "unixepoch"
|
||||
TimeFormatUnixFrac TimeFormat = "unixepoch_frac"
|
||||
TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format
|
||||
TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format
|
||||
TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format
|
||||
|
||||
// Auto
|
||||
TimeFormatAuto TimeFormat = "auto"
|
||||
)
|
||||
|
||||
// Encode encodes a time value using this format.
|
||||
//
|
||||
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
|
||||
// with nanosecond accuracy, and preserving any timezone offset.
|
||||
//
|
||||
// This is the format used by the [database/sql] driver:
|
||||
// [database/sql.Row.Scan] will decode as [time.Time]
|
||||
// values encoded with [time.RFC3339Nano].
|
||||
//
|
||||
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
|
||||
// to produce a time-ordered sequence.
|
||||
//
|
||||
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
|
||||
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
|
||||
// use the TIME [collating sequence] to produce a time-ordered sequence.
|
||||
//
|
||||
// Otherwise, use [TimeFormat7] for time-ordered encoding.
|
||||
//
|
||||
// Formats [TimeFormat1] through [TimeFormat10]
|
||||
// convert time values to UTC before encoding.
|
||||
//
|
||||
// Returns a string for the text formats,
|
||||
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
|
||||
// or an int64 for the other numeric formats.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
//
|
||||
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
|
||||
func (f TimeFormat) Encode(t time.Time) any {
|
||||
switch f {
|
||||
// Numeric formats
|
||||
case TimeFormatJulianDay:
|
||||
return julianday.Float(t)
|
||||
case TimeFormatUnix:
|
||||
return t.Unix()
|
||||
case TimeFormatUnixFrac:
|
||||
return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
|
||||
case TimeFormatUnixMilli:
|
||||
return t.UnixMilli()
|
||||
case TimeFormatUnixMicro:
|
||||
return t.UnixMicro()
|
||||
case TimeFormatUnixNano:
|
||||
return t.UnixNano()
|
||||
// Special formats
|
||||
case TimeFormatDefault, TimeFormatAuto:
|
||||
f = time.RFC3339Nano
|
||||
// SQLite assumes UTC if unspecified.
|
||||
case
|
||||
TimeFormat1, TimeFormat2,
|
||||
TimeFormat3, TimeFormat4,
|
||||
TimeFormat5, TimeFormat6,
|
||||
TimeFormat7, TimeFormat8,
|
||||
TimeFormat9, TimeFormat10:
|
||||
t = t.UTC()
|
||||
}
|
||||
return t.Format(string(f))
|
||||
}
|
||||
|
||||
// Decode decodes a time value using this format.
|
||||
//
|
||||
// The time value can be a string, an int64, or a float64.
|
||||
//
|
||||
// Formats [TimeFormat8] through [TimeFormat10]
|
||||
// (and [TimeFormat8TZ] through [TimeFormat10TZ])
|
||||
// assume a date of 2000-01-01.
|
||||
//
|
||||
// The timezone indicator and fractional seconds are always optional
|
||||
// for formats [TimeFormat2] through [TimeFormat10]
|
||||
// (and [TimeFormat2TZ] through [TimeFormat10TZ]).
|
||||
//
|
||||
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
|
||||
// Julian day numbers are safe to use for historical dates,
|
||||
// from 4712BC through 9999AD.
|
||||
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
|
||||
// are safe to use for current events, from at least 1980 through at least 2260.
|
||||
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
|
||||
// or have the wrong time unit.
|
||||
//
|
||||
// https://www.sqlite.org/lang_datefunc.html
|
||||
func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
switch f {
|
||||
// Numeric formats
|
||||
case TimeFormatJulianDay:
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
return julianday.Parse(v)
|
||||
case float64:
|
||||
return julianday.FloatTime(v), nil
|
||||
case int64:
|
||||
return julianday.Time(v, 0), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnix, TimeFormatUnixFrac:
|
||||
if s, ok := v.(string); ok {
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
v = f
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
sec, frac := math.Modf(v)
|
||||
nsec := math.Floor(frac * 1e9)
|
||||
return time.Unix(int64(sec), int64(nsec)), nil
|
||||
case int64:
|
||||
return time.Unix(v, 0), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnixMilli:
|
||||
if s, ok := v.(string); ok {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
v = i
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMilli(int64(math.Floor(v))), nil
|
||||
case int64:
|
||||
return time.UnixMilli(int64(v)), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnixMicro:
|
||||
if s, ok := v.(string); ok {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
v = i
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMicro(int64(math.Floor(v))), nil
|
||||
case int64:
|
||||
return time.UnixMicro(int64(v)), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
|
||||
case TimeFormatUnixNano:
|
||||
if s, ok := v.(string); ok {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
v = i
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.Unix(0, int64(math.Floor(v))), nil
|
||||
case int64:
|
||||
return time.Unix(0, int64(v)), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
|
||||
// Special formats
|
||||
case TimeFormatAuto:
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
v = i
|
||||
break
|
||||
}
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err == nil {
|
||||
v = f
|
||||
break
|
||||
}
|
||||
|
||||
dates := []TimeFormat{
|
||||
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
|
||||
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
|
||||
TimeFormat1,
|
||||
}
|
||||
for _, f := range dates {
|
||||
t, err := time.Parse(string(f), s)
|
||||
if err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
times := []TimeFormat{
|
||||
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
|
||||
}
|
||||
for _, f := range times {
|
||||
t, err := time.Parse(string(f), s)
|
||||
if err == nil {
|
||||
return t.AddDate(2000, 0, 0), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
if 0 <= v && v < 5373484.5 {
|
||||
return TimeFormatJulianDay.Decode(v)
|
||||
}
|
||||
if v < 253402300800 {
|
||||
return TimeFormatUnixFrac.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000 {
|
||||
return TimeFormatUnixMilli.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000000 {
|
||||
return TimeFormatUnixMicro.Decode(v)
|
||||
}
|
||||
return TimeFormatUnixNano.Decode(v)
|
||||
case int64:
|
||||
if 0 <= v && v < 5373485 {
|
||||
return TimeFormatJulianDay.Decode(v)
|
||||
}
|
||||
if v < 253402300800 {
|
||||
return TimeFormatUnixFrac.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000 {
|
||||
return TimeFormatUnixMilli.Decode(v)
|
||||
}
|
||||
if v < 253402300800_000000 {
|
||||
return TimeFormatUnixMicro.Decode(v)
|
||||
}
|
||||
return TimeFormatUnixNano.Decode(v)
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
|
||||
case
|
||||
TimeFormat2, TimeFormat2TZ,
|
||||
TimeFormat3, TimeFormat3TZ,
|
||||
TimeFormat4, TimeFormat4TZ,
|
||||
TimeFormat5, TimeFormat5TZ,
|
||||
TimeFormat6, TimeFormat6TZ,
|
||||
TimeFormat7, TimeFormat7TZ:
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
return f.parseRelaxed(s)
|
||||
|
||||
case
|
||||
TimeFormat8, TimeFormat8TZ,
|
||||
TimeFormat9, TimeFormat9TZ,
|
||||
TimeFormat10, TimeFormat10TZ:
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
t, err := f.parseRelaxed(s)
|
||||
return t.AddDate(2000, 0, 0), err
|
||||
|
||||
default:
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
if f == "" {
|
||||
f = time.RFC3339Nano
|
||||
}
|
||||
return time.Parse(string(f), s)
|
||||
}
|
||||
}
|
||||
|
||||
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
|
||||
fs := string(f)
|
||||
fs = strings.TrimSuffix(fs, "Z07:00")
|
||||
fs = strings.TrimSuffix(fs, ".000")
|
||||
t, err := time.Parse(fs+"Z07:00", s)
|
||||
if err != nil {
|
||||
return time.Parse(fs, s)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
202
tx.go
Normal file
202
tx.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Tx is an in-progress database transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
type Tx struct {
|
||||
c *Conn
|
||||
}
|
||||
|
||||
// Begin starts a deferred transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (c *Conn) Begin() Tx {
|
||||
// BEGIN even if interrupted.
|
||||
err := c.txExecInterrupted(`BEGIN DEFERRED`)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return Tx{c}
|
||||
}
|
||||
|
||||
// BeginImmediate starts an immediate transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (c *Conn) BeginImmediate() (Tx, error) {
|
||||
err := c.Exec(`BEGIN IMMEDIATE`)
|
||||
if err != nil {
|
||||
return Tx{}, err
|
||||
}
|
||||
return Tx{c}, nil
|
||||
}
|
||||
|
||||
// BeginExclusive starts an exclusive transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (c *Conn) BeginExclusive() (Tx, error) {
|
||||
err := c.Exec(`BEGIN EXCLUSIVE`)
|
||||
if err != nil {
|
||||
return Tx{}, err
|
||||
}
|
||||
return Tx{c}, nil
|
||||
}
|
||||
|
||||
// End calls either [Tx.Commit] or [Tx.Rollback]
|
||||
// depending on whether *error points to a nil or non-nil error.
|
||||
//
|
||||
// This is meant to be deferred:
|
||||
//
|
||||
// func doWork(conn *sqlite3.Conn) (err error) {
|
||||
// tx := conn.Begin()
|
||||
// defer tx.End(&err)
|
||||
//
|
||||
// // ... do work in the transaction
|
||||
// }
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (tx Tx) End(errp *error) {
|
||||
recovered := recover()
|
||||
if recovered != nil {
|
||||
defer panic(recovered)
|
||||
}
|
||||
|
||||
if *errp == nil && recovered == nil {
|
||||
// Success path.
|
||||
if tx.c.GetAutocommit() { // There is nothing to commit.
|
||||
return
|
||||
}
|
||||
*errp = tx.Commit()
|
||||
if *errp == nil {
|
||||
return
|
||||
}
|
||||
// Fall through to the error path.
|
||||
}
|
||||
|
||||
// Error path.
|
||||
if tx.c.GetAutocommit() { // There is nothing to rollback.
|
||||
return
|
||||
}
|
||||
err := tx.Rollback()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (tx Tx) Commit() error {
|
||||
return tx.c.Exec(`COMMIT`)
|
||||
}
|
||||
|
||||
// Rollback rolls back the transaction,
|
||||
// even if the connection has been interrupted.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (tx Tx) Rollback() error {
|
||||
return tx.c.txExecInterrupted(`ROLLBACK`)
|
||||
}
|
||||
|
||||
// Savepoint is a marker within a transaction
|
||||
// that allows for partial rollback.
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
type Savepoint struct {
|
||||
c *Conn
|
||||
name string
|
||||
}
|
||||
|
||||
// Savepoint establishes a new transaction savepoint.
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func (c *Conn) Savepoint() Savepoint {
|
||||
name := "sqlite3.Savepoint"
|
||||
var pc [1]uintptr
|
||||
if n := runtime.Callers(2, pc[:]); n > 0 {
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, _ := frames.Next()
|
||||
if frame.Function != "" {
|
||||
name = frame.Function
|
||||
}
|
||||
}
|
||||
// Names can be reused; this makes catching bugs more likely.
|
||||
name += "#" + strconv.Itoa(int(rand.Int31()))
|
||||
|
||||
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return Savepoint{c: c, name: name}
|
||||
}
|
||||
|
||||
// Release releases the savepoint rolling back any changes
|
||||
// if *error points to a non-nil error.
|
||||
//
|
||||
// This is meant to be deferred:
|
||||
//
|
||||
// func doWork(conn *sqlite3.Conn) (err error) {
|
||||
// savept := conn.Savepoint()
|
||||
// defer savept.Release(&err)
|
||||
//
|
||||
// // ... do work in the transaction
|
||||
// }
|
||||
func (s Savepoint) Release(errp *error) {
|
||||
recovered := recover()
|
||||
if recovered != nil {
|
||||
defer panic(recovered)
|
||||
}
|
||||
|
||||
if *errp == nil && recovered == nil {
|
||||
// Success path.
|
||||
if s.c.GetAutocommit() { // There is nothing to commit.
|
||||
return
|
||||
}
|
||||
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
|
||||
if *errp == nil {
|
||||
return
|
||||
}
|
||||
// Fall through to the error path.
|
||||
}
|
||||
|
||||
// Error path.
|
||||
if s.c.GetAutocommit() { // There is nothing to rollback.
|
||||
return
|
||||
}
|
||||
// ROLLBACK and RELEASE even if interrupted.
|
||||
err := s.c.txExecInterrupted(fmt.Sprintf(`
|
||||
ROLLBACK TO %[1]q;
|
||||
RELEASE %[1]q;
|
||||
`, s.name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Rollback rolls the transaction back to the savepoint,
|
||||
// even if the connection has been interrupted.
|
||||
// Rollback does not release the savepoint.
|
||||
//
|
||||
// https://www.sqlite.org/lang_transaction.html
|
||||
func (s Savepoint) Rollback() error {
|
||||
// ROLLBACK even if interrupted.
|
||||
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
|
||||
}
|
||||
|
||||
func (c *Conn) txExecInterrupted(sql string) error {
|
||||
err := c.Exec(sql)
|
||||
if errors.Is(err, INTERRUPT) {
|
||||
old := c.SetInterrupt(context.Background())
|
||||
defer c.SetInterrupt(old)
|
||||
err = c.Exec(sql)
|
||||
}
|
||||
return err
|
||||
}
|
||||
161
util_test.go
161
util_test.go
@@ -1,161 +0,0 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
|
||||
func newMemory(size uint32) memory {
|
||||
mem := make(mockMemory, size)
|
||||
return memory{mockModule{&mem}}
|
||||
}
|
||||
|
||||
type mockModule struct {
|
||||
memory api.Memory
|
||||
}
|
||||
|
||||
func (m mockModule) Memory() api.Memory { return m.memory }
|
||||
func (m mockModule) String() string { return "mockModule" }
|
||||
func (m mockModule) Name() string { return "mockModule" }
|
||||
|
||||
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
|
||||
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
|
||||
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
|
||||
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
|
||||
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
|
||||
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
|
||||
func (m mockModule) Close(context.Context) error { return nil }
|
||||
|
||||
type mockMemory []byte
|
||||
|
||||
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
|
||||
|
||||
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
|
||||
|
||||
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
|
||||
if offset >= m.Size() {
|
||||
return 0, false
|
||||
}
|
||||
return m[offset], true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
|
||||
v, ok := m.ReadUint32Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float32frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
|
||||
v, ok := m.ReadUint64Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float64frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
|
||||
if !m.hasSize(offset, byteCount) {
|
||||
return nil, false
|
||||
}
|
||||
return m[offset : offset+byteCount : offset+byteCount], true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
|
||||
if offset >= m.Size() {
|
||||
return false
|
||||
}
|
||||
m[offset] = v
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint16(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint32(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
|
||||
return m.WriteUint32Le(offset, math.Float32bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint64(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
|
||||
return m.WriteUint64Le(offset, math.Float64bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) Write(offset uint32, val []byte) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteString(offset uint32, val string) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
|
||||
prev := (len(*m) + 65535) / 65536
|
||||
*m = append(*m, make([]byte, 65536*delta)...)
|
||||
return uint32(prev), true
|
||||
}
|
||||
|
||||
func (m mockMemory) PageSize() (result uint32) {
|
||||
return uint32(len(m) / 65536)
|
||||
}
|
||||
|
||||
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
|
||||
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
|
||||
}
|
||||
125
value.go
Normal file
125
value.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Value is any value that can be stored in a database table.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value.html
|
||||
type Value struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Type returns the initial [Datatype] of the value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Type() Datatype {
|
||||
r := v.call(v.api.valueType, uint64(v.handle))
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
// Bool returns the value as a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are retrieved as integers,
|
||||
// with 0 converted to false and any other value to true.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Bool() bool {
|
||||
if i := v.Int64(); i != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Int returns the value as an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int() int {
|
||||
return int(v.Int64())
|
||||
}
|
||||
|
||||
// Int64 returns the value as an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int64() int64 {
|
||||
r := v.call(v.api.valueInteger, uint64(v.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// Float returns the value as a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Float() float64 {
|
||||
r := v.call(v.api.valueFloat, uint64(v.handle))
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// Time returns the value as a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Time(format TimeFormat) time.Time {
|
||||
var a any
|
||||
switch v.Type() {
|
||||
case INTEGER:
|
||||
a = v.Int64()
|
||||
case FLOAT:
|
||||
a = v.Float()
|
||||
case TEXT, BLOB:
|
||||
a = v.Text()
|
||||
case NULL:
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
t, _ := format.Decode(a)
|
||||
return t
|
||||
}
|
||||
|
||||
// Text returns the value as a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Text() string {
|
||||
return string(v.RawText())
|
||||
}
|
||||
|
||||
// Blob appends to buf and returns
|
||||
// the value as a []byte.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Blob(buf []byte) []byte {
|
||||
return append(buf, v.RawBlob()...)
|
||||
}
|
||||
|
||||
// RawText returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawText() []byte {
|
||||
r := v.call(v.api.valueText, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
// RawBlob returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawBlob() []byte {
|
||||
r := v.call(v.api.valueBlob, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
func (v Value) rawBytes(ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := v.call(v.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user