Compare commits

...

55 Commits

Author SHA1 Message Date
Nuno Cruces
828788912e JSON example. 2023-11-09 12:11:36 +00:00
Nuno Cruces
6f8645cd2e Tests. 2023-11-08 07:28:48 +00:00
Nuno Cruces
c00927e8bb Driver savepoints. 2023-11-07 15:19:40 +00:00
dependabot[bot]
6b28be6d0e Bump golang.org/x/sys from 0.13.0 to 0.14.0 (#36)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.13.0 to 0.14.0.
- [Commits](https://github.com/golang/sys/compare/v0.13.0...v0.14.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-07 15:00:00 +00:00
dependabot[bot]
310b4ff29d Bump golang.org/x/sync from 0.4.0 to 0.5.0 (#35)
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.4.0 to 0.5.0.
- [Commits](https://github.com/golang/sync/compare/v0.4.0...v0.5.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-07 12:38:12 +00:00
dependabot[bot]
e82cf16b11 Bump golang.org/x/text from 0.13.0 to 0.14.0 (#34)
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.13.0 to 0.14.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.13.0...v0.14.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-07 00:57:28 +00:00
Nuno Cruces
24c9b57c56 Pointer-passing interfaces. 2023-11-07 00:50:43 +00:00
Nuno Cruces
24b965ac7e Refactor. 2023-11-06 18:29:28 +00:00
Nuno Cruces
446168c572 Update workflows. 2023-11-04 11:21:31 +00:00
Nuno Cruces
a9e2cbbfc5 Quote values, identifiers. 2023-11-04 01:18:25 +00:00
Nuno Cruces
a7c00eb150 SQLite 3.44.0. 2023-11-03 03:43:14 -07:00
Nuno Cruces
0bcdb712ba SQL json_time function. 2023-11-03 03:40:46 -07:00
Nuno Cruces
2157d0f325 Interrupts: avoid goroutine. 2023-10-25 14:12:21 +01:00
Nuno Cruces
6353160619 Improve benchmark repeatability. 2023-10-25 13:17:37 +01:00
Nuno Cruces
501d157279 Update BSD test. 2023-10-24 23:27:10 +01:00
Nuno Cruces
4db18a7b9a JSON encoding fix. 2023-10-19 16:46:58 +01:00
Nuno Cruces
a9dddaa86c Optimize VFS search. 2023-10-19 16:43:54 +01:00
Nuno Cruces
b25936dbec Unix formats return UTC. 2023-10-19 12:07:03 +01:00
dependabot[bot]
bf23041e46 Bump github.com/ncruces/julianday from 0.1.5 to 1.0.0 (#33)
Bumps [github.com/ncruces/julianday](https://github.com/ncruces/julianday) from 0.1.5 to 1.0.0.
- [Commits](https://github.com/ncruces/julianday/compare/v0.1.5...v1.0.0)

---
updated-dependencies:
- dependency-name: github.com/ncruces/julianday
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-19 00:34:35 +01:00
Nuno Cruces
d60fceac92 JSON support. 2023-10-18 23:14:46 +01:00
Nuno Cruces
61da30f44a Allow configuring wazero. 2023-10-18 15:06:32 +01:00
Nuno Cruces
d4ff605983 Time scanner. 2023-10-18 15:06:12 +01:00
Nuno Cruces
8d0c654178 Cross compilation. 2023-10-17 15:30:08 +01:00
Nuno Cruces
728e59951b Test BSD. 2023-10-16 12:51:49 +01:00
Nuno Cruces
f7b16bad5c Patch flaky tests. 2023-10-16 12:26:25 +01:00
Nuno Cruces
db3e6da31a BSD locks. 2023-10-16 02:11:20 +01:00
Nuno Cruces
3f443b2ecc API change. 2023-10-13 18:53:37 +01:00
Nuno Cruces
eec45ea684 Towards JSON. 2023-10-13 17:06:05 +01:00
Nuno Cruces
f6d77f3cf4 GORM v1.25.5. 2023-10-13 00:42:06 +01:00
Nuno Cruces
d5d7cd1f2d SQLite 3.43.2. 2023-10-12 10:52:48 +01:00
dependabot[bot]
a33a187d48 Bump golang.org/x/sys from 0.12.0 to 0.13.0 (#30)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.12.0 to 0.13.0.
- [Commits](https://github.com/golang/sys/compare/v0.12.0...v0.13.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-05 23:53:11 +01:00
dependabot[bot]
70c6ee15c6 Bump golang.org/x/sync from 0.3.0 to 0.4.0 (#29)
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.3.0 to 0.4.0.
- [Commits](https://github.com/golang/sync/compare/v0.3.0...v0.4.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-05 23:43:11 +01:00
Nuno Cruces
994d9b1812 Updated dependencies. 2023-10-02 10:09:26 +01:00
Nuno Cruces
b19bd28ed3 Simplify lock timeouts. 2023-10-02 10:06:09 +01:00
Nuno Cruces
e66bd51845 More VFS API. 2023-09-21 02:43:45 +01:00
Nuno Cruces
f5614bc2ed Tweaks. 2023-09-20 15:07:07 +01:00
Nuno Cruces
d9fcf60b7d Driver API. 2023-09-20 02:41:09 +01:00
Nuno Cruces
ac6dd1aa5f Updated dependencies. 2023-09-18 15:22:11 +01:00
Nuno Cruces
b1495bd6cb Build tags, docs. 2023-09-18 15:11:05 +01:00
Nuno Cruces
2d91760295 Portability. 2023-09-18 12:44:18 +01:00
Nuno Cruces
38d4254bc4 Update README.md 2023-09-15 15:37:57 +01:00
Nuno Cruces
c0aa734786 binaryen-version_116. 2023-09-15 15:10:08 +01:00
Nuno Cruces
fa845dbd3d Run test in all platforms. 2023-09-12 15:30:43 +01:00
Nuno Cruces
fed315ab79 Update go.yml 2023-09-12 15:28:11 +01:00
Nuno Cruces
726d7316f7 Update README.md 2023-09-12 00:00:32 +01:00
Nuno Cruces
ddb387b021 Updated dependencies. 2023-09-11 23:54:22 +01:00
Nuno Cruces
d0f19507f5 SQLite 3.43.1. 2023-09-11 23:48:38 +01:00
Nuno Cruces
9d997552ad Pearson correlation. 2023-09-02 00:48:55 +01:00
Nuno Cruces
9d75c39dcc Update README.md 2023-09-01 16:01:42 +01:00
Nuno Cruces
746a84965e Covariance. 2023-09-01 02:38:57 +01:00
Nuno Cruces
312d3b58f2 Statistics functions. 2023-09-01 01:23:25 +01:00
Nuno Cruces
b71cd295c2 Updated dependencies. 2023-08-25 09:56:09 +01:00
Nuno Cruces
5b3b61a304 SQLite 3.43.0. 2023-08-24 18:56:23 +01:00
Nuno Cruces
d661d15723 wazero v1.5.0. 2023-08-24 18:56:10 +01:00
Nuno Cruces
1e38165ad0 Timer resolution. 2023-08-20 03:12:55 +01:00
102 changed files with 2704 additions and 903 deletions

29
.github/workflows/bsd.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: BSD
on:
workflow_dispatch:
jobs:
test:
runs-on: macos-12
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: GOOS=freebsd go test -c ./...
- name: Test
uses: cross-platform-actions/action@v0.21.1
with:
operating_system: freebsd
version: '13.2'
sync_files: runner-to-vm
run: find . -name '*.test' -maxdepth 1 -exec {} -test.v \;

View File

@@ -1,76 +0,0 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ "main" ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ "main" ]
schedule:
- cron: '15 18 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
# Use only 'java' to analyze code written in Java, Kotlin or both
# Use only 'javascript' to analyze code written in JavaScript, TypeScript or both
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
steps:
- name: Checkout repository
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
# Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
# If the Autobuild fails above, remove it and uncomment the following three lines.
# modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
# - run: |
# echo "Run, Build Application using script"
# ./location_of_script_within_repo/buildscript.sh
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{matrix.language}}"

22
.github/workflows/cross.sh vendored Executable file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env bash
echo android ; GOOS=android GOARCH=amd64 go build .
echo darwin ; GOOS=darwin GOARCH=amd64 go build .
echo dragonfly ; GOOS=dragonfly GOARCH=amd64 go build .
echo freebsd ; GOOS=freebsd GOARCH=amd64 go build .
echo illumos ; GOOS=illumos GOARCH=amd64 go build .
echo ios ; GOOS=ios GOARCH=amd64 go build .
echo linux ; GOOS=linux GOARCH=amd64 go build .
echo netbsd ; GOOS=netbsd GOARCH=amd64 go build .
echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
echo windows ; GOOS=windows GOARCH=amd64 go build .
# echo aix ; GOOS=aix GOARCH=ppc64 go build .
echo js ; GOOS=js GOARCH=wasm go build .
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
echo darwin-nosys ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_nosys .
echo linux-nosys ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_nosys .
echo windows-nosys ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_nosys .
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .

21
.github/workflows/cross.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: Cross compile
on:
workflow_dispatch:
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: .github/workflows/cross.sh

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
lfs: 'true'
@@ -34,13 +34,11 @@ jobs:
- name: Download
run: go mod download
# Fixed in go 1.21: https://go.dev/issue/54372
# - name: Verify
# run: go mod verify
- name: Verify
run: go mod verify
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -48,17 +46,19 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test no locks
run: go test -v -tags sqlite3_nosys ./tests -run TestDB_nolock
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
run: go test -v -tags sqlite3_flock ./...
if: matrix.os == 'macos-latest'
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
reuse-go: 'true'
chart: true
amend: true
reuse-go: true
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'
continue-on-error: true
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'

View File

@@ -7,18 +7,28 @@
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- [`github.com/ncruces/go-sqlite3/ext/blob`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blob)
simplifies incremental BLOB I/O.
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
registers Unicode aware functions.
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
implements an in-memory VFS.
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
implements a VFS for immutable databases.
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Caveats
@@ -35,30 +45,39 @@ To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
to always use `EXCLUSIVE` locking mode for WAL databases.
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
to open WAL databases you should disable connection pooling by calling
to use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with WAL mode databases you should disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### POSIX Advisory Locks
#### File Locking
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
POSIX advisory locks, which SQLite uses on Unix, are
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
OFD locks are fully compatible with POSIX advisory locks.
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks
(they are on FreeBSD).
On BSD Unixes, BSD locks are fully compatible with POSIX advisory locks.
On Windows, this module uses `LockFile`, `LockFileEx`, and `UnlockFile`,
like SQLite.
On all other platforms, file locking is not supported, and you must use
[`nolock=1`](https://www.sqlite.org/uri.html#urinolock)
to open database files.
To use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with `nolock=1` you must disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Testing
The pure Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows;
BSD code paths are tested on macOS using the `sqlite3_bsd` build tag.
on Linux, macOS, Windows and FreeBSD.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
@@ -69,17 +88,18 @@ Performance is tested by running
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [x] JSON support
- [ ] virtual tables
- [ ] session extension
- [ ] custom VFSes
- [x] custom VFS API
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] [MVCC](https://en.wikipedia.org/wiki/Multiversion_concurrency_control) VFS, using [BadgerDB](https://github.com/dgraph-io/badger)
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

View File

@@ -118,8 +118,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
return 0, nil
}
want := int64(1024 * 1024)
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
@@ -175,8 +175,11 @@ func (b *Blob) Write(p []byte) (n int, err error) {
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want := int64(1024 * 1024)
avail := b.bytes - b.offset
want := int64(65536)
if l, ok := r.(*io.LimitedReader); ok && want > l.N {
want = l.N
}
if want > avail {
want = avail
}

94
conn.go
View File

@@ -2,16 +2,14 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
@@ -22,7 +20,6 @@ type Conn struct {
*sqlite
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
@@ -49,6 +46,8 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(filename, flags)
}
type connKey struct{}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
sqlite, err := instantiateSQLite()
if err != nil {
@@ -64,6 +63,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
@@ -132,7 +132,6 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
@@ -240,65 +239,45 @@ func (c *Conn) Changes() int64 {
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
// Is it the same context?
if ctx == c.interrupt {
return ctx
}
// Reset the pending statement.
if c.pending != nil {
// An uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
} else {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
// Remove the handler if the context can't be canceled.
if ctx == nil || ctx.Done() == nil {
c.call(c.api.progressHandler, uint64(c.handle), 0)
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
c.call(c.api.progressHandler, uint64(c.handle), 100)
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt != nil && c.interrupt.Err() != nil {
return 1
}
}
return 0
}
func (c *Conn) checkInterrupt() {
if c.interrupt != nil && c.interrupt.Err() != nil {
c.call(c.api.interrupt, uint64(c.handle))
}
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.
@@ -324,22 +303,9 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints], [online backup] and [incremental BLOB I/O].
// It can be used to access SQLite features like [online backup].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
SetInterrupt(ctx context.Context) (old context.Context)
Savepoint() Savepoint
Backup(srcDB, dstURI string) error
Restore(dstDB, srcURI string) error
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
Raw() *Conn
}

View File

@@ -97,6 +97,7 @@ const (
IOERR_ROLLBACK_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (31 << 8)
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"encoding/json"
"errors"
"math"
"time"
@@ -13,24 +14,33 @@ import (
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
*sqlite
c *Conn
handle uint32
}
// Conn returns the database connection of the
// [Conn.CreateFunction] or [Conn.CreateWindowFunction]
// routines that originally registered the application defined function.
//
// https://sqlite.org/c3ref/context_db_handle.html
func (ctx Context) Conn() *Conn {
return ctx.c
}
// SetAuxData saves metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(c.ctx, data)
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
func (ctx Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(ctx.c.ctx, data)
ctx.c.call(ctx.c.api.setAuxData, uint64(ctx.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) GetAuxData(n int) any {
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
return util.GetHandle(c.ctx, ptr)
func (ctx Context) GetAuxData(n int) any {
ptr := uint32(ctx.c.call(ctx.c.api.getAuxData, uint64(ctx.handle), uint64(n)))
return util.GetHandle(ctx.c.ctx, ptr)
}
// ResultBool sets the result of the function to a bool.
@@ -38,125 +48,160 @@ func (c Context) GetAuxData(n int) any {
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBool(value bool) {
func (ctx Context) ResultBool(value bool) {
var i int64
if value {
i = 1
}
c.ResultInt64(i)
ctx.ResultInt64(i)
}
// ResultInt sets the result of the function to an int.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt(value int) {
c.ResultInt64(int64(value))
func (ctx Context) ResultInt(value int) {
ctx.ResultInt64(int64(value))
}
// ResultInt64 sets the result of the function to an int64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt64(value int64) {
c.call(c.api.resultInteger,
uint64(c.handle), uint64(value))
func (ctx Context) ResultInt64(value int64) {
ctx.c.call(ctx.c.api.resultInteger,
uint64(ctx.handle), uint64(value))
}
// ResultFloat sets the result of the function to a float64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultFloat(value float64) {
c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value))
func (ctx Context) ResultFloat(value float64) {
ctx.c.call(ctx.c.api.resultFloat,
uint64(ctx.handle), math.Float64bits(value))
}
// ResultText sets the result of the function to a string.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultText(value string) {
ptr := c.newString(value)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor), _UTF8)
func (ctx Context) ResultText(value string) {
ptr := ctx.c.newString(value)
ctx.c.call(ctx.c.api.resultText,
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBlob(value []byte) {
ptr := c.newBytes(value)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor))
func (ctx Context) ResultBlob(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call(ctx.c.api.resultBlob,
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor))
}
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultZeroBlob(n int64) {
c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n))
func (ctx Context) ResultZeroBlob(n int64) {
ctx.c.call(ctx.c.api.resultZeroBlob,
uint64(ctx.handle), uint64(n))
}
// ResultNull sets the result of the function to NULL.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultNull() {
c.call(c.api.resultNull,
uint64(c.handle))
func (ctx Context) ResultNull() {
ctx.c.call(ctx.c.api.resultNull,
uint64(ctx.handle))
}
// ResultTime sets the result of the function to a [time.Time].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultTime(value time.Time, format TimeFormat) {
func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
if format == TimeFormatDefault {
c.resultRFC3339Nano(value)
ctx.resultRFC3339Nano(value)
return
}
switch v := format.Encode(value).(type) {
case string:
c.ResultText(v)
ctx.ResultText(v)
case int64:
c.ResultInt64(v)
ctx.ResultInt64(v)
case float64:
c.ResultFloat(v)
ctx.ResultFloat(v)
default:
panic(util.AssertErr())
}
}
func (c Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano))
func (ctx Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := c.new(maxlen)
buf := util.View(c.mod, ptr, maxlen)
ptr := ctx.c.new(maxlen)
buf := util.View(ctx.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.api.destructor), _UTF8)
ctx.c.call(ctx.c.api.resultText,
uint64(ctx.handle), uint64(ptr), uint64(len(buf)),
uint64(ctx.c.api.destructor), _UTF8)
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
// except that it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
ctx.ResultError(err)
}
ptr := ctx.c.newBytes(data)
ctx.c.call(ctx.c.api.resultText,
uint64(ctx.handle), uint64(ptr), uint64(len(data)),
uint64(ctx.c.api.destructor))
}
// ResultValue sets the result of the function a copy of [Value].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultValue(value Value) {
if value.sqlite != ctx.c.sqlite {
ctx.ResultError(MISUSE)
}
ctx.c.call(ctx.c.api.resultValue,
uint64(ctx.handle), uint64(value.handle))
}
// ResultError sets the result of the function an error.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultError(err error) {
func (ctx Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
c.call(c.api.resultErrorMem, uint64(c.handle))
ctx.c.call(ctx.c.api.resultErrorMem, uint64(ctx.handle))
return
}
if errors.Is(err, TOOBIG) {
c.call(c.api.resultErrorBig, uint64(c.handle))
ctx.c.call(ctx.c.api.resultErrorBig, uint64(ctx.handle))
return
}
str := err.Error()
ptr := c.newString(str)
c.call(c.api.resultError,
uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
ptr := ctx.c.newString(str)
ctx.c.call(ctx.c.api.resultError,
uint64(ctx.handle), uint64(ptr), uint64(len(str)))
ctx.c.free(ptr)
var code uint64
var ecode ErrorCode
@@ -168,7 +213,7 @@ func (c Context) ResultError(err error) {
code = uint64(ecode)
}
if code != 0 {
c.call(c.api.resultErrorCode,
uint64(c.handle), code)
ctx.c.call(ctx.c.api.resultErrorCode,
uint64(ctx.handle), code)
}
}

View File

@@ -30,6 +30,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
@@ -40,14 +42,34 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
var driverName = "sqlite3"
func init() {
sql.Register("sqlite3", sqlite{})
if driverName != "" {
sql.Register(driverName, sqlite{})
}
}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to [database/sql].
func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite{}.OpenConnector(name)
c, err := newConnector(name, nil)
if err != nil {
return nil, err
}
@@ -55,7 +77,11 @@ func (sqlite) Open(name string) (driver.Conn, error) {
}
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
c := connector{name: name}
return newConnector(name, nil)
}
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
@@ -70,6 +96,7 @@ func (sqlite) OpenConnector(name string) (driver.Connector, error) {
}
type connector struct {
init func(*sqlite3.Conn) error
name string
txlock string
pragmas bool
@@ -107,19 +134,22 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
if err != nil {
return nil, err
}
c.reusable = true
} else {
s, _, err := c.Conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
}
if n.init != nil {
err = n.init(c.Conn)
if err != nil {
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
if n.pragmas || n.init != nil {
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
if err != nil {
return nil, err
}
if s.Step() && s.ColumnBool(0) {
c.readOnly = '1'
} else {
c.readOnly = '0'
}
err = s.Close()
if err != nil {
@@ -134,20 +164,19 @@ type conn struct {
txBegin string
txCommit string
txRollback string
reusable bool
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
_ driver.ConnPrepareContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c *conn) IsValid() bool {
return c.reusable
func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
func (c *conn) Begin() (driver.Tx, error) {
@@ -163,10 +192,10 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
c.txRollback = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
c.txCommit = c.txRollback
}
switch opts.Isolation {
@@ -190,14 +219,20 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
if err != nil && !c.Conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c *conn) Rollback() error {
return c.Conn.Exec(c.txRollback)
err := c.Conn.Exec(c.txRollback)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(c.txRollback)
}
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
@@ -234,6 +269,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return nil, driver.ErrSkip
}
if savept, ok := ctx.(*saveptCtx); ok {
// Called from driver.Savepoint.
savept.Savepoint = c.Savepoint()
return resultRowsAffected(0), nil
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
@@ -245,6 +286,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return newResult(c.Conn), nil
}
func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
return nil
}
type stmt struct {
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
@@ -341,8 +386,12 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case interface{ Value() any }:
err = s.Stmt.BindPointer(id, a.Value())
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case json.Marshaler:
err = s.Stmt.BindJSON(id, a)
case nil:
err = s.Stmt.BindNull(id)
default:
@@ -359,7 +408,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
sqlite3.ZeroBlob, interface{ Value() any },
time.Time, json.Marshaler, nil:
return nil
default:
return driver.ErrSkip
@@ -438,11 +488,7 @@ func (r *rows) Next(dest []driver.Value) error {
case sqlite3.TEXT:
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
} else {
dest[i] = nil
}
dest[i] = nil
default:
panic(util.AssertErr())
}

65
driver/json_test.go Normal file
View File

@@ -0,0 +1,65 @@
package driver_test
import (
"fmt"
"log"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Example_json() {
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`
CREATE TABLE orders (
cart_id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
cart TEXT
);
`)
if err != nil {
log.Fatal(err)
}
type CartItem struct {
ItemID string `json:"id"`
Name string `json:"name"`
Quantity int `json:"quantity,omitempty"`
Price int `json:"price,omitempty"`
}
type Cart struct {
Items []CartItem `json:"items"`
}
_, err = db.Exec(`INSERT INTO orders (user_id, cart) VALUES (?, ?)`, 123, sqlite3.JSON(Cart{
[]CartItem{
{ItemID: "111", Name: "T-shirt", Quantity: 1, Price: 250},
{ItemID: "222", Name: "Trousers", Quantity: 1, Price: 600},
},
}))
if err != nil {
log.Fatal(err)
}
var total string
err = db.QueryRow(`
SELECT total(json_each.value -> 'price')
FROM orders, json_each(cart -> 'items')
WHERE cart_id = last_insert_rowid()
`).Scan(&total)
if err != nil {
log.Fatal(err)
}
fmt.Println("total:", total)
// Output:
// total: 850
}

27
driver/savepoint.go Normal file
View File

@@ -0,0 +1,27 @@
package driver
import (
"database/sql"
"time"
"github.com/ncruces/go-sqlite3"
)
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func Savepoint(tx *sql.Tx) sqlite3.Savepoint {
var ctx saveptCtx
tx.ExecContext(&ctx, "")
return ctx.Savepoint
}
type saveptCtx struct{ sqlite3.Savepoint }
func (*saveptCtx) Deadline() (deadline time.Time, ok bool) { return }
func (*saveptCtx) Done() <-chan struct{} { return nil }
func (*saveptCtx) Err() error { return nil }
func (*saveptCtx) Value(key any) any { return nil }

87
driver/savepoint_test.go Normal file
View File

@@ -0,0 +1,87 @@
package driver_test
import (
"fmt"
"log"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func ExampleSavepoint() {
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = func() error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO users (id, name) VALUES (?, ?)`)
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(0, "go")
if err != nil {
return err
}
_, err = stmt.Exec(1, "zig")
if err != nil {
return err
}
savept := driver.Savepoint(tx)
_, err = stmt.Exec(2, "whatever")
if err != nil {
return err
}
err = savept.Rollback()
if err != nil {
return err
}
_, err = stmt.Exec(3, "rust")
if err != nil {
return err
}
return tx.Commit()
}()
if err != nil {
log.Fatal(err)
}
rows, err := db.Query(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id, name string
err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s %s\n", id, name)
}
// Output:
// 0 go
// 1 zig
// 3 rust
}

View File

@@ -1,75 +0,0 @@
package sqlite3_test
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
func ExampleDriverConn() {
var err error
db, err = sql.Open("sqlite3", "demo.db")
if err != nil {
log.Fatal(err)
}
defer os.Remove("demo.db")
defer db.Close()
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
log.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
if err != nil {
log.Fatal(err)
}
id, err := res.LastInsertId()
if err != nil {
log.Fatal(err)
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {
return err
}
defer blob.Close()
_, err = fmt.Fprint(blob, "Hello BLOB!")
return err
})
if err != nil {
log.Fatal(err)
}
var msg string
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
if err != nil {
log.Fatal(err)
}
fmt.Println(msg)
// Output:
// Hello BLOB!
}

View File

@@ -1,6 +1,6 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
This folder includes an embeddable WASM build of SQLite 3.44.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

View File

@@ -13,6 +13,8 @@ sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_interrupt
sqlite3_progress_handler_go
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
@@ -23,6 +25,7 @@ sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_bind_pointer_go
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
@@ -62,12 +65,15 @@ sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_value_pointer_go
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_pointer_go
sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem

Binary file not shown.

59
ext/blob/blob.go Normal file
View File

@@ -0,0 +1,59 @@
// Package blob provides an alternative interface to incremental BLOB I/O.
package blob
import (
"errors"
"github.com/ncruces/go-sqlite3"
)
// Register registers the blob_open SQL function.
func Register(db *sqlite3.Conn) {
db.CreateFunction("blob_open", -1,
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
}
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
return
}
row := arg[3].Int64()
var err error
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
if ok {
err = blob.Reopen(row)
if errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table or column changed).
ok = false
}
}
if !ok {
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
write := arg[4].Bool()
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
}
if err != nil {
ctx.ResultError(err)
return
}
fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return
}
// This ensures the blob is closed if db, table or column change.
ctx.SetAuxData(0, blob)
ctx.SetAuxData(1, blob)
ctx.SetAuxData(2, blob)
}
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error

61
ext/blob/blob_test.go Normal file
View File

@@ -0,0 +1,61 @@
package blob_test
import (
"io"
"log"
"os"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/blob"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blob.Register(conn)
return nil
})
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
const message = "Hello BLOB!"
// Create the BLOB.
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}
// Write the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.WriteString(blob, message)
return err
}))
if err != nil {
log.Fatal(err)
}
// Read the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.Copy(os.Stdout, blob)
return err
}))
if err != nil {
log.Fatal(err)
}
// Output:
// Hello BLOB!
}

109
ext/stats/stats.go Normal file
View File

@@ -0,0 +1,109 @@
// Package stats provides aggregate functions for statistics.
//
// Functions:
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - var_pop: population variance
// - var_samp: sample variance
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
//
// See: [ANSI SQL Aggregate Functions]
//
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import "github.com/ncruces/go-sqlite3"
// Register registers statistics functions.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
}
const (
var_pop = iota
var_samp
stddev_pop
stddev_samp
corr
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
}
type variance struct {
kind int
welford
}
func (fn *variance) Value(ctx sqlite3.Context) {
var r float64
switch fn.kind {
case var_pop:
r = fn.var_pop()
case var_samp:
r = fn.var_samp()
case stddev_pop:
r = fn.stddev_pop()
case stddev_samp:
r = fn.stddev_samp()
}
ctx.ResultFloat(r)
}
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
fn.enqueue(a.Float())
}
}
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
fn.dequeue(a.Float())
}
}
func newCovariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
}
type covariance struct {
kind int
welford2
}
func (fn *covariance) Value(ctx sqlite3.Context) {
var r float64
switch fn.kind {
case var_pop:
r = fn.covar_pop()
case var_samp:
r = fn.covar_samp()
case corr:
r = fn.correlation()
}
ctx.ResultFloat(r)
}
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
fn.enqueue(a.Float(), b.Float())
}
}
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
fn.dequeue(a.Float(), b.Float())
}
}

140
ext/stats/stats_test.go Normal file
View File

@@ -0,0 +1,140 @@
package stats
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestRegister_variance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
stddev_samp(x), stddev_pop(x)
FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
}
if got := stmt.ColumnFloat(1); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(2); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := stmt.ColumnFloat(3); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}
func TestRegister_covariance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}

109
ext/stats/welford.go Normal file
View File

@@ -0,0 +1,109 @@
package stats
import "math"
// Welford's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type welford struct {
m1, m2 kahan
n uint64
}
func (w welford) average() float64 {
return w.m1.hi
}
func (w welford) var_pop() float64 {
return w.m2.hi / float64(w.n)
}
func (w welford) var_samp() float64 {
return w.m2.hi / float64(w.n-1) // Bessel's correction
}
func (w welford) stddev_pop() float64 {
return math.Sqrt(w.var_pop())
}
func (w welford) stddev_samp() float64 {
return math.Sqrt(w.var_samp())
}
func (w *welford) enqueue(x float64) {
w.n++
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}
func (w *welford) dequeue(x float64) {
w.n--
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
type welford2 struct {
m1x, m2x kahan
m1y, m2y kahan
cov kahan
n uint64
}
func (w welford2) covar_pop() float64 {
return w.cov.hi / float64(w.n)
}
func (w welford2) covar_samp() float64 {
return w.cov.hi / float64(w.n-1) // Bessel's correction
}
func (w welford2) correlation() float64 {
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
}
func (w *welford2) enqueue(x, y float64) {
w.n++
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.add(d1x / float64(w.n))
w.m1y.add(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.add(d1x * d2x)
w.m2y.add(d1y * d2y)
w.cov.add(d1x * d2y)
}
func (w *welford2) dequeue(x, y float64) {
w.n--
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.sub(d1x / float64(w.n))
w.m1y.sub(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.sub(d1x * d2x)
w.m2y.sub(d1y * d2y)
w.cov.sub(d1x * d2y)
}
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

75
ext/stats/welford_test.go Normal file
View File

@@ -0,0 +1,75 @@
package stats
import (
"math"
"testing"
)
func Test_welford(t *testing.T) {
var s1, s2 welford
s1.enqueue(4)
s1.enqueue(7)
s1.enqueue(13)
s1.enqueue(16)
if got := s1.average(); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := s1.var_samp(); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := s1.var_pop(); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := s1.stddev_samp(); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
s1.dequeue(4)
s2.enqueue(7)
s2.enqueue(13)
s2.enqueue(16)
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
}
func Test_covar(t *testing.T) {
var c1, c2 welford2
c1.enqueue(3, 70)
c1.enqueue(5, 80)
c1.enqueue(2, 60)
c1.enqueue(7, 90)
c1.enqueue(4, 75)
if got := c1.covar_samp(); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := c1.covar_pop(); got != 17 {
t.Errorf("got %v, want 17", got)
}
c1.dequeue(3, 70)
c2.enqueue(5, 80)
c2.enqueue(2, 60)
c2.enqueue(7, 90)
c2.enqueue(4, 75)
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
}
func Test_correlation(t *testing.T) {
var c welford2
c.enqueue(1, 3)
c.enqueue(2, 2)
c.enqueue(3, 1)
if got := c.correlation(); got != -1 {
t.Errorf("got %v, want -1", got)
}
}

View File

@@ -1,17 +1,19 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Provides Unicode aware:
// - upper and lower functions,
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions,
// - LIKE and REGEXP operators,
// - collation sequences.
//
// This package is not 100% compatible with the ICU extension:
// - upper and lower use [strings.ToUpper], [strings.ToLower] and [cases];
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regex/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
package unicode
import (
@@ -45,7 +47,7 @@ func Register(db *sqlite3.Conn) {
return
}
err := RegisterCollation(db, name, arg[0].Text())
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
@@ -53,8 +55,9 @@ func Register(db *sqlite3.Conn) {
})
}
func RegisterCollation(db *sqlite3.Conn, name, lang string) error {
tag, err := language.Parse(lang)
// RegisterCollation registers a Unicode collation sequence for a database connection.
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
tag, err := language.Parse(locale)
if err != nil {
return err
}

73
func.go
View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
@@ -12,7 +11,7 @@ import (
// for any unknown collating sequence.
// The fake collating function works like BINARY.
//
// This extension can be used to load schemas that contain
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
@@ -47,6 +46,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
@@ -70,7 +70,7 @@ type AggregateFunction interface {
// The function arguments, if any, corresponding to the row being added are passed to Step.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current value of the aggregate.
// Value is invoked to return the current (or final) value of the aggregate.
Value(ctx Context)
}
@@ -85,17 +85,6 @@ type WindowFunction interface {
Inverse(ctx Context, arg ...Value)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
@@ -106,57 +95,57 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
db := ctx.Value(connKey{}).(*Conn)
fn := callbackHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, &handle).(AggregateFunction)
fn.Value(Context{db, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{sqlite, pCtx}.ResultError(err)
Context{db, pCtx}.ResultError(err)
}
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
fn.Value(Context{db, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(WindowFunction)
fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
return util.GetHandle(sqlite.ctx, pApp)
func callbackHandle(db *Conn, pCtx uint32) any {
pApp := uint32(db.call(db.api.userData, uint64(pCtx)))
return util.GetHandle(db.ctx, pApp)
}
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
ptr := uint32(db.call(db.api.aggregateCtx, uint64(pCtx), size))
// Try loading the handle, if we already have one, or want a new one.
if ptr != 0 || size != 0 {
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
fn := util.GetHandle(sqlite.ctx, handle)
if handle := util.ReadUint32(db.mod, ptr); handle != 0 {
fn := util.GetHandle(db.ctx, handle)
if close != nil {
*close = handle
}
@@ -167,19 +156,19 @@ func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
}
// Create a new aggregate and store the handle.
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
fn := callbackHandle(db, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn))
}
return fn
}
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
func callbackArgs(db *Conn, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
sqlite: sqlite,
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
sqlite: db.sqlite,
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
}
}
return args

View File

@@ -26,7 +26,7 @@ func ExampleConn_CreateWindowFunction() {
log.Fatal(err)
}
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
if err != nil {
log.Fatal(err)
}

10
go.mod
View File

@@ -3,12 +3,12 @@ module github.com/ncruces/go-sqlite3
go 1.21
require (
github.com/ncruces/julianday v0.1.5
github.com/ncruces/julianday v1.0.0
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.4.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.11.0
golang.org/x/text v0.12.0
github.com/tetratelabs/wazero v1.5.0
golang.org/x/sync v0.5.0
golang.org/x/sys v0.14.0
golang.org/x/text v0.14.0
)
retract v0.4.0 // tagged from the wrong branch

20
go.sum
View File

@@ -1,12 +1,12 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.4.0 h1:9/MirYvmkJ/zSUOygKY/ia3t+e+RqIZXKbylIby1WYk=
github.com/tetratelabs/wazero v1.4.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=

View File

@@ -1,3 +1,4 @@
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=

View File

@@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) {
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
return &result, nil
}
func (d *ddl) clone() *ddl {
copied := new(ddl)
*copied = *d
copied.fields = make([]string, len(d.fields))
copy(copied.fields, d.fields)
copied.columns = make([]migrator.ColumnType, len(d.columns))
copy(copied.columns, d.columns)
return copied
}
func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}
func (d *ddl) renameTable(dst, src string) error {
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
if replaced == d.head {
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
}
d.head = replaced
return nil
}
func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
@@ -208,6 +235,18 @@ func (d *ddl) removeConstraint(name string) bool {
return false
}
//lint:ignore U1000 ignore unused code.
func (d *ddl) hasConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for _, f := range d.fields {
if reg.MatchString(f) {
return true
}
}
return false
}
func (d *ddl) getColumns() []string {
res := []string{}
@@ -229,3 +268,30 @@ func (d *ddl) getColumns() []string {
}
return res
}
func (d *ddl) alterColumn(name, sql string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return false
}
}
d.fields = append(d.fields, sql)
return true
}
func (d *ddl) removeColumn(name string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}
return false
}

View File

@@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) {
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
@@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
@@ -122,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
@@ -131,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},

View File

@@ -7,7 +7,7 @@ import (
"gorm.io/gorm"
)
func (dialector Dialector) Translate(err error) error {
func (_Dialector) Translate(err error) error {
switch {
case
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),

View File

@@ -3,14 +3,14 @@ module github.com/ncruces/go-sqlite3/gormlite
go 1.21
require (
github.com/ncruces/go-sqlite3 v0.8.5
gorm.io/gorm v1.25.4
github.com/ncruces/go-sqlite3 v0.9.1
gorm.io/gorm v1.25.5
)
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.4.0 // indirect
golang.org/x/sys v0.11.0 // indirect
github.com/tetratelabs/wazero v1.5.0 // indirect
golang.org/x/sys v0.13.0 // indirect
)

View File

@@ -2,14 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.8.5 h1:JeNcbJ4rsZ07ZVyqPdnFlfmVSWDW0ONoiuZSUBC369Y=
github.com/ncruces/go-sqlite3 v0.8.5/go.mod h1:XvDtjKk5MgwHX7L4I7BPzzKl36bTZ7+Hr6Kr2QeVkVw=
github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo=
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.4.0 h1:9/MirYvmkJ/zSUOygKY/ia3t+e+RqIZXKbylIby1WYk=
github.com/tetratelabs/wazero v1.4.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View File

@@ -3,7 +3,6 @@ package gormlite
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
@@ -12,11 +11,11 @@ import (
"gorm.io/gorm/schema"
)
type Migrator struct {
type _Migrator struct {
migrator.Migrator
}
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
func (m *_Migrator) RunWithoutForeignKey(fc func() error) error {
var enabled int
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
if enabled == 1 {
@@ -27,7 +26,7 @@ func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
return fc()
}
func (m Migrator) HasTable(value interface{}) bool {
func (m _Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
@@ -35,7 +34,7 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
func (m _Migrator) DropTable(values ...interface{}) error {
return m.RunWithoutForeignKey(func() error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
@@ -52,11 +51,11 @@ func (m Migrator) DropTable(values ...interface{}) error {
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
func (m _Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
func (m _Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
@@ -76,31 +75,24 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
func (m _Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
}
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
})
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
@@ -148,29 +140,23 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func (m _Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, "")
return createSQL, nil, nil
ddl.removeColumn(name)
return ddl, nil, nil
})
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
func (m _Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
var (
constraintName string
constraintSql string
@@ -185,22 +171,16 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
return nil, nil, nil
}
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()
return createSQL, constraintValues, nil
ddl.addConstraint(constraintName, constraintSql)
return ddl, constraintValues, nil
})
})
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
func (m _Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
@@ -210,20 +190,14 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
}
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()
return createSQL, nil, nil
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
ddl.removeConstraint(name)
return ddl, nil, nil
})
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
func (m _Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@@ -244,13 +218,13 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
func (m _Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
@@ -269,7 +243,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
func (m _Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@@ -298,7 +272,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
func (m _Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
@@ -317,7 +291,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
@@ -331,7 +305,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
func (m _Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@@ -365,7 +339,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
return
}
func (m Migrator) getRawDDL(table string) (string, error) {
func (m _Migrator) getRawDDL(table string) (string, error) {
var createSQL string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
@@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
return createSQL, nil
}
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
func (m _Migrator) recreateTable(
value interface{}, tablePtr *string,
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
@@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
return err
}
newTableName := table + "__temp"
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
originDDL, err := parseDDL(rawDDL)
if err != nil {
return err
}
if createSQL == "" {
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
if err != nil {
return err
}
if createDDL == nil {
return nil
}
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
newTableName := table + "__temp"
if err := createDDL.renameTable(newTableName, table); err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createDDL, err := parseDDL(createSQL)
if err != nil {
return err
}
columns := createDDL.getColumns()
createSQL := createDDL.compile()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {

View File

@@ -13,27 +13,33 @@ import (
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
// Open opens a GORM dialector from a data source name.
func Open(dsn string) gorm.Dialector {
return &_Dialector{DSN: dsn}
}
// Open opens a GORM dialector from a database handle.
func OpenDB(db *sql.DB) gorm.Dialector {
return &_Dialector{Conn: db}
}
type _Dialector struct {
DSN string
Conn gorm.ConnPool
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
func (dialector _Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
conn, err := driver.Open(dialector.DSN, nil)
if err != nil {
return err
}
@@ -48,7 +54,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
@@ -64,7 +70,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
@@ -113,7 +119,7 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
if field.AutoIncrement {
return clause.Expr{SQL: "NULL"}
}
@@ -122,19 +128,19 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return _Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) {
var (
underQuoted, selfQuoted bool
continuousBacktick int8
@@ -182,16 +188,17 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString("`")
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
func (dialector _Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
func (dialector _Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
if field.AutoIncrement {
// doesn't check `PrimaryKey`, to keep backward compatibility
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
@@ -215,12 +222,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
return string(field.DataType)
}
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}

View File

@@ -6,6 +6,8 @@ import (
"gorm.io/gorm"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -13,22 +15,52 @@ func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
// Custom connection with a custom function called "my_custom_function".
db, err := driver.Open(InMemoryDSN, func(conn *sqlite3.Conn) error {
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText("my-result")
})
})
if err != nil {
t.Fatal(err)
}
rows := []struct {
description string
dialector *Dialector
dialector gorm.Dialector
openSuccess bool
query string
querySuccess bool
}{
{
description: "Default driver",
dialector: &Dialector{
DSN: InMemoryDSN,
},
description: "Default driver",
dialector: Open(InMemoryDSN),
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom function",
dialector: Open(InMemoryDSN),
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: false,
},
{
description: "Custom connection",
dialector: OpenDB(db),
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom connection, custom function",
dialector: OpenDB(db),
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {

View File

@@ -72,6 +72,7 @@ const (
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
IOERR_DATA = IOERR | (32 << 8)
IOERR_CORRUPTFS = IOERR | (33 << 8)
IOERR_IN_PAGE = IOERR | (34 << 8)
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
LOCKED_VTAB = LOCKED | (2 << 8)
BUSY_RECOVERY = BUSY | (1 << 8)

View File

@@ -23,6 +23,7 @@ const (
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
ValueErr = ErrorString("sqlite3: unsupported value")
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
)

View File

@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
if size > math.MaxUint32 {
panic(RangeErr)
}
if size == 0 {
return nil
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)

56
json.go Normal file
View File

@@ -0,0 +1,56 @@
package sqlite3
import (
"encoding/json"
"strconv"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// JSON returns:
// a [json.Marshaler] that can be used as an argument to
// [database/sql.DB.Exec] and similar methods to
// store value as JSON; and
// a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode JSON into value.
func JSON(value any) any {
return jsonValue{value}
}
type jsonValue struct{ any }
func (j jsonValue) MarshalJSON() ([]byte, error) {
return json.Marshal(j.any)
}
func (j jsonValue) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, j.any)
}
func (j jsonValue) Scan(value any) error {
var buf []byte
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = append(buf, "null"...)
default:
panic(util.AssertErr())
}
return j.UnmarshalJSON(buf)
}

14
pointer.go Normal file
View File

@@ -0,0 +1,14 @@
package sqlite3
// Pointer returns a pointer to a value
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
//
// https://www.sqlite.org/bindptr.html
func Pointer[T any](val T) any {
return pointer[T]{val}
}
type pointer[T any] struct{ val T }
func (p pointer[T]) Value() any { return p.val }

112
quote.go Normal file
View File

@@ -0,0 +1,112 @@
package sqlite3
import (
"bytes"
"math"
"strconv"
"strings"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Quote escapes and quotes a value
// making it safe to embed in SQL text.
func Quote(value any) string {
switch v := value.(type) {
case nil:
return "NULL"
case bool:
if v {
return "1"
} else {
return "0"
}
case int:
return strconv.Itoa(v)
case int64:
return strconv.FormatInt(v, 10)
case float64:
switch {
case math.IsNaN(v):
return "NULL"
case math.IsInf(v, 1):
return "9.0e999"
case math.IsInf(v, -1):
return "-9.0e999"
}
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return "'" + v.Format(time.RFC3339Nano) + "'"
case string:
if strings.IndexByte(v, 0) >= 0 {
break
}
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
buf[0] = '\''
i := 1
for _, b := range []byte(v) {
if b == '\'' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[i] = '\''
return unsafe.String(&buf[0], len(buf))
case []byte:
buf := make([]byte, 3+2*len(v))
buf[0] = 'x'
buf[1] = '\''
i := 2
for _, b := range v {
const hex = "0123456789ABCDEF"
buf[i+0] = hex[b/16]
buf[i+1] = hex[b%16]
i += 2
}
buf[i] = '\''
return unsafe.String(&buf[0], len(buf))
case ZeroBlob:
if v > ZeroBlob(1e9-3)/2 {
break
}
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[0] = 'x'
buf[1] = '\''
buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf))
}
panic(util.ValueErr)
}
// QuoteIdentifier escapes and quotes an identifier
// making it safe to embed in SQL text.
func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr)
}
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
buf[0] = '"'
i := 1
for _, b := range []byte(id) {
if b == '"' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[i] = '"'
return unsafe.String(&buf[0], len(buf))
}

View File

@@ -22,6 +22,8 @@ import (
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
RuntimeConfig wazero.RuntimeConfig
)
var instance struct {
@@ -32,12 +34,16 @@ var instance struct {
}
func compileSQLite() {
if RuntimeConfig == nil {
RuntimeConfig = wazero.NewRuntimeConfig()
}
ctx := context.Background()
instance.runtime = wazero.NewRuntime(ctx)
instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportHostFunctions(env)
env = exportCallbacks(env)
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
@@ -65,8 +71,6 @@ type sqlite struct {
stack [8]uint64
}
type sqliteKey struct{}
func instantiateSQLite() (sqlt *sqlite, err error) {
instance.once.Do(compileSQLite)
if instance.err != nil {
@@ -75,7 +79,6 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig())
@@ -117,6 +120,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
interrupt: getFun("sqlite3_interrupt"),
progressHandler: getFun("sqlite3_progress_handler_go"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
@@ -127,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindPointer: getFun("sqlite3_bind_pointer_go"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
@@ -164,12 +170,15 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
valuePointer: getFun("sqlite3_value_pointer_go"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultPointer: getFun("sqlite3_result_pointer_go"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
@@ -200,13 +209,15 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if handle != 0 {
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
}
@@ -245,7 +256,7 @@ func (sqlt *sqlite) new(size uint64) uint32 {
}
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil {
if (*[0]byte)(b) == nil {
return 0
}
ptr := sqlt.new(uint64(len(b)))
@@ -333,6 +344,8 @@ type sqliteAPI struct {
reset api.Function
step api.Function
exec api.Function
interrupt api.Function
progressHandler api.Function
clearBindings api.Function
bindCount api.Function
bindIndex api.Function
@@ -343,6 +356,7 @@ type sqliteAPI struct {
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindPointer api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
@@ -380,15 +394,30 @@ type sqliteAPI struct {
valueText api.Function
valueBlob api.Function
valueBytes api.Function
valuePointer api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultPointer api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
destructor uint32
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", callbackProgress)
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}

View File

@@ -3,33 +3,33 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3440000.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
cat *.patch | patch
cat *.patch | patch --posix
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/anycollseq.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/test/speedtest1.c"
cd ~-

View File

@@ -1,4 +1,4 @@
#include <string.h>
#include <stddef.h>
#include "sqlite3.h"
@@ -18,14 +18,15 @@ int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
go_func, NULL, NULL, go_destroy);
go_func, /*step=*/NULL, /*final=*/NULL,
go_destroy);
}
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
int nArg, int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, NULL, NULL,
go_destroy);
pApp, go_step, go_final, /*value=*/NULL,
/*inverse=*/NULL, go_destroy);
}
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
@@ -38,3 +39,17 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"
int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
}
void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
}
void *sqlite3_value_pointer_go(sqlite3_value *val) {
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
}

34
sqlite3/isoweek.patch Normal file
View File

@@ -0,0 +1,34 @@
# ISO week date specifiers.
# https://sqlite.org/forum/forumpost/73d99e4497e8e6a7
--- sqlite3.c.orig
+++ sqlite3.c
@@ -1373,6 +1373,29 @@ static void strftimeFunc(
sqlite3_str_appendchar(&sRes, 1, c);
break;
}
+ case 'V': /* Fall thru */
+ case 'G': {
+ DateTime y = x;
+ computeJD(&y);
+ y.validYMD = 0;
+ /* Adjust date to Thursday this week:
+ The number in parentheses is 0 for Monday, 3 for Thursday */
+ y.iJD += (3 - (((y.iJD+43200000)/86400000) % 7))*86400000;
+ computeYMD(&y);
+ if( cf=='G' ){
+ sqlite3_str_appendf(&sRes,"%04d",y.Y);
+ }else{
+ int nDay; /* Number of days since 1st day of year */
+ i64 tJD = y.iJD;
+ y.validJD = 0;
+ y.M = 1;
+ y.D = 1;
+ computeJD(&y);
+ nDay = (int)((tJD-y.iJD+43200000)/86400000);
+ sqlite3_str_appendf(&sRes,"%02d",nDay/7+1);
+ }
+ break;
+ }
case 'Y': {
sqlite3_str_appendf(&sRes,"%04d",x.Y);
break;

View File

@@ -1,6 +1,3 @@
#include <stdbool.h>
#include <stddef.h>
// Amalgamation
#include "sqlite3.c"
// VFS
@@ -14,6 +11,7 @@
#include "ext/uint.c"
#include "ext/uuid.c"
#include "func.c"
#include "progress.c"
#include "time.c"
__attribute__((constructor)) void init() {
@@ -25,4 +23,4 @@ __attribute__((constructor)) void init() {
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}
}

View File

@@ -1,26 +0,0 @@
# Allow the VFS to force memory journal mode
# regardless of SQLITE_OMIT_DESERIALIZE.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -60425,11 +60425,7 @@
int rc = SQLITE_OK; /* Return code */
int tempFile = 0; /* True for temp files (incl. in-memory files) */
int memDb = 0; /* True if this is an in-memory file */
-#ifndef SQLITE_OMIT_DESERIALIZE
int memJM = 0; /* Memory journal mode */
-#else
-# define memJM 0
-#endif
int readOnly = 0; /* True if this is a read-only file */
int journalFileSize; /* Bytes to allocate for each journal fd */
char *zPathname = 0; /* Full path to database file */
@@ -60628,9 +60624,7 @@
int fout = 0; /* VFS flags returned by xOpen() */
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
assert( !memDb );
-#ifndef SQLITE_OMIT_DESERIALIZE
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
-#endif
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;
/* If the file was successfully opened for read/write access,

9
sqlite3/progress.c Normal file
View File

@@ -0,0 +1,9 @@
#include <stddef.h>
#include "sqlite3.h"
int go_progress(void *);
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL);
}

View File

@@ -5,12 +5,28 @@
#define SQLITE_OS_OTHER 1
#define SQLITE_BYTEORDER 1234
#define HAVE_INT8_T 1
#define HAVE_INT16_T 1
#define HAVE_INT32_T 1
#define HAVE_INT64_T 1
#define HAVE_UINT8_T 1
#define HAVE_UINT16_T 1
#define HAVE_UINT32_T 1
#define HAVE_UINT64_T 1
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define HAVE_LOG2 1
#define HAVE_LOG10 1
#define HAVE_ISNAN 1
#define HAVE_USLEEP 1
#define HAVE_NANOSLEEP 1
#define HAVE_GMTIME_R 1
#define HAVE_LOCALTIME_S 1
#define HAVE_MALLOC_H 1
#define HAVE_MALLOC_USABLE_SIZE 1
// Recommended Options
@@ -23,7 +39,6 @@
#define SQLITE_MAX_EXPR_DEPTH 0
#define SQLITE_OMIT_DECLTYPE
#define SQLITE_OMIT_DEPRECATED
#define SQLITE_OMIT_PROGRESS_CALLBACK
#define SQLITE_OMIT_SHARED_CACHE
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
@@ -52,11 +67,11 @@
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
#define SQLITE_SOUNDEX
// Session Extension
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
#define SQLITE_SOUNDEX
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

View File

@@ -1,3 +1,4 @@
#include <stddef.h>
#include <string.h>
#include "sqlite3.h"
@@ -26,7 +27,63 @@ static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
return rc;
}
static void json_time_func(sqlite3_context *context, int argc,
sqlite3_value **argv) {
DateTime x;
if (isDate(context, argc, argv, &x)) return;
if (x.tzSet && x.tz) {
x.iJD += x.tz * 60000;
if (!validJulianDay(x.iJD)) return;
x.validYMD = 0;
x.validHMS = 0;
}
computeYMD_HMS(&x);
sqlite3 *db = sqlite3_context_db_handle(context);
sqlite3_str *res = sqlite3_str_new(db);
sqlite3_str_appendf(res, "%04d-%02d-%02dT%02d:%02d:%02d", //
x.Y, x.M, x.D, //
x.h, x.m, (int)(x.iJD / 1000 % 60));
if (x.useSubsec) {
int rem = x.iJD % 1000;
if (rem) {
sqlite3_str_appendchar(res, 1, '.');
sqlite3_str_appendchar(res, 1, '0' + rem / 100);
if ((rem %= 100)) {
sqlite3_str_appendchar(res, 1, '0' + rem / 10);
if ((rem %= 10)) {
sqlite3_str_appendchar(res, 1, '0' + rem);
}
}
}
}
if (x.tz) {
sqlite3_str_appendf(res, "%+03d:%02d", x.tz / 60, abs(x.tz) % 60);
} else {
sqlite3_str_appendchar(res, 1, 'Z');
}
int rc = sqlite3_str_errcode(res);
if (rc) {
sqlite3_result_error_code(context, rc);
return;
}
int n = sqlite3_str_length(res);
sqlite3_result_text(context, sqlite3_str_finish(res), n, sqlite3_free);
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
sqlite3_create_collation_v2(db, "time", SQLITE_UTF8, /*arg=*/NULL,
time_collation,
/*destroy=*/NULL);
sqlite3_create_function_v2(
db, "json_time", -1,
SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, /*arg=*/NULL,
json_time_func, /*step=*/NULL, /*final=*/NULL, /*destroy=*/NULL);
return SQLITE_OK;
}

45
sqlite3/timezone.patch Normal file
View File

@@ -0,0 +1,45 @@
# Set UTC timezone, compute local offset.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -340,6 +340,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
p->iJD = sqlite3StmtCurrentTime(context);
if( p->iJD>0 ){
p->validJD = 1;
+ p->tzSet = 1;
return 0;
}else{
return 1;
@@ -355,6 +356,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
static void setRawDateNumber(DateTime *p, double r){
p->s = r;
p->rawS = 1;
+ p->tzSet = 1;
if( r>=0.0 && r<5373484.5 ){
p->iJD = (sqlite3_int64)(r*86400000.0 + 0.5);
p->validJD = 1;
@@ -731,7 +733,16 @@ static int parseModifier(
** show local time.
*/
if( sqlite3_stricmp(z, "localtime")==0 && sqlite3NotPureFunc(pCtx) ){
- rc = toLocaltime(p, pCtx);
+ if( p->tzSet!=0 || p->tz==0 ) {
+ rc = toLocaltime(p, pCtx);
+ i64 iOrigJD = p->iJD;
+ p->tzSet = 0;
+ computeJD(p);
+ p->tz = (p->iJD-iOrigJD)/60000;
+ if( abs(p->tz)>= 900 ) p->tz = 0;
+ } else {
+ rc = 0;
+ }
}
break;
}
@@ -781,6 +792,7 @@ static int parseModifier(
p->validJD = 1;
p->tzSet = 1;
}
+ p->tz = 0;
rc = SQLITE_OK;
}
#endif

View File

@@ -1,3 +1,5 @@
#include <stdbool.h>
#include <stddef.h>
#include <time.h>
#include "sqlite3.h"
@@ -90,22 +92,25 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
if (zVfsName) {
static sqlite3_vfs *go_vfs_list;
sqlite3_vfs *found = NULL;
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
sqlite3_vfs *it = *next;
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
if (!strcmp(zVfsName, it->zName) && go_vfs_find(it->zName)) {
return it;
}
}
for (sqlite3_vfs **ptr = &go_vfs_list; *ptr;) {
sqlite3_vfs *it = *ptr;
if (go_vfs_find(it->zName)) {
if (!strcmp(zVfsName, it->zName)) found = it;
next = &it->pNext;
ptr = &it->pNext;
} else {
*next = it->pNext;
*ptr = it->pNext;
free(it);
}
}
if (found) {
return found;
}
if (go_vfs_find(zVfsName)) {
sqlite3_vfs *prev = go_vfs_list;
sqlite3_vfs *head = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
@@ -114,7 +119,7 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = name,
.pNext = prev,
.pNext = head,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
@@ -132,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
return sqlite3_vfs_find_orig(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");

View File

@@ -3,6 +3,7 @@ package sqlite3
import (
"bytes"
"math"
"os"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -56,6 +57,12 @@ func Test_sqlite_new(t *testing.T) {
})
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
defer func() { _ = recover() }()
sqlite.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
@@ -132,6 +139,15 @@ func Test_sqlite_newBytes(t *testing.T) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
ptr = sqlite.newBytes(buf[:0])
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
if got := util.View(sqlite.mod, ptr, 0); got != nil {
t.Errorf("got %q, want nil", got)
}
}
func Test_sqlite_newString(t *testing.T) {

57
stmt.go
View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -235,7 +237,7 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
@@ -248,6 +250,35 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r)
}
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
// but it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
r := s.c.call(s.c.api.bindPointer,
uint64(s.handle), uint64(param), uint64(valPtr))
return s.c.error(r)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
ptr := s.c.newBytes(data)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(data)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -402,6 +433,30 @@ func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
return util.View(s.c.mod, ptr, r)
}
// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
var data []byte
switch s.ColumnType(col) {
case NULL:
data = append(data, "null"...)
case TEXT:
data = s.ColumnRawText(col)
case BLOB:
data = s.ColumnRawBlob(col)
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.

View File

@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
go cancel()
// Interrupting works.
err = stmt.Exec()

View File

@@ -25,6 +25,13 @@ func TestDB_file(t *testing.T) {
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func TestDB_nolock(t *testing.T) {
t.Parallel()
testDB(t, "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?nolock=1")
}
func TestDB_wal(t *testing.T) {
t.Parallel()
wal := filepath.Join(t.TempDir(), "test.db")

View File

@@ -2,10 +2,9 @@ package tests
import (
"context"
"database/sql"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -36,8 +36,17 @@ func TestCreateFunction(t *testing.T) {
case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8:
ctx.ResultNull()
var v any
if err := arg.JSON(&v); err != nil {
ctx.ResultError(err)
} else {
ctx.ResultJSON(v)
}
case 9:
ctx.ResultValue(arg)
case 10:
ctx.ResultNull()
case 11:
ctx.ResultError(sqlite3.FULL)
}
})
@@ -45,7 +54,7 @@ func TestCreateFunction(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
if err != nil {
t.Error(err)
}
@@ -123,6 +132,27 @@ func TestCreateFunction(t *testing.T) {
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 8 {
t.Errorf("got %v, want 8", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 9 {
t.Errorf("got %v, want 9", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)

68
tests/json_test.go Normal file
View File

@@ -0,0 +1,68 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/julianday"
)
func TestJSON(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
nil, 1, math.Pi, reference,
)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
sqlite3.JSON(math.Pi), sqlite3.JSON(false),
julianday.Format(reference), sqlite3.JSON([]string{}))
if err != nil {
t.Fatal(err)
}
rows, err := db.Query("SELECT * FROM test")
if err != nil {
t.Fatal(err)
}
want := []string{
"null", "1", "3.141592653589793",
`"2013-10-07T04:23:19.12-04:00"`,
"3.141592653589793", "false",
"2456572.849526851851852", "[]",
}
for rows.Next() {
var got json.RawMessage
err = rows.Scan(sqlite3.JSON(&got))
if err != nil {
t.Fatal(err)
}
if string(got) != want[0] {
t.Errorf("got %q, want %q", got, want[0])
}
want = want[1:]
}
}

82
tests/quote_test.go Normal file
View File

@@ -0,0 +1,82 @@
package tests
import (
"math"
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestQuote(t *testing.T) {
tests := []struct {
val any
want string
}{
{`abc`, "'abc'"},
{`a"bc`, "'a\"bc'"},
{`a'bc`, "'a''bc'"},
{"\x07bc", "'\abc'"},
{"\x1c\n", "'\x1c\n'"},
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
{"\xB0\x00\x0B", ""},
{0, "0"},
{true, "1"},
{false, "0"},
{nil, "NULL"},
{math.NaN(), "NULL"},
{math.Inf(1), "9.0e999"},
{math.Inf(-1), "-9.0e999"},
{math.Pi, "3.141592653589793"},
{int64(math.MaxInt64), "9223372036854775807"},
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
{sqlite3.ZeroBlob(4), "x'00000000'"},
{sqlite3.ZeroBlob(1e9), ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && tt.want != "" {
t.Errorf("Quote(%q) = %v", tt.val, r)
}
}()
got := sqlite3.Quote(tt.val)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
}
})
}
}
func TestQuoteIdentifier(t *testing.T) {
tests := []struct {
id string
want string
}{
{`abc`, `"abc"`},
{`a"bc`, `"a""bc"`},
{`a'bc`, `"a'bc"`},
{"\x07bc", "\"\abc\""},
{"\x1c\n", "\"\x1c\n\""},
{"\xB0\x00\x0B", ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && tt.want != "" {
t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r)
}
}()
got := sqlite3.QuoteIdentifier(tt.id)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
}
})
}
}

View File

@@ -1,6 +1,7 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindJSON(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
// The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 0 {
t.Errorf("got %v, want zero", got)
}
}
if stmt.Step() {
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
var got float32
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 1 {
t.Errorf("got %v, want one", got)
}
}
if stmt.Step() {
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
var got json.Number
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != "2" {
t.Errorf("got %v, want two", got)
}
}
if stmt.Step() {
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
var got float64
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != math.Pi {
t.Errorf("got %v, want π", got)
}
}
if stmt.Step() {
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "true" {
t.Errorf("got %q, want true", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
t.Errorf("got %q, want true", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if err := stmt.Close(); err != nil {

View File

@@ -1,11 +1,13 @@
package tests
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
)
func TestTimeFormat_Encode(t *testing.T) {
@@ -39,70 +41,72 @@ func TestTimeFormat_Encode(t *testing.T) {
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
zone := time.FixedZone("", -4*3600)
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, zone)
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, zone)
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantLoc *time.Location
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, time.UTC, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, time.UTC, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, time.UTC, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, time.UTC, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, zone, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, time.UTC, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, zone, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, time.UTC, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, time.UTC, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, nil, true},
}
for _, tt := range tests {
@@ -112,13 +116,48 @@ func TestTimeFormat_Decode(t *testing.T) {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
if got.Sub(tt.want).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
if got.Location().String() != tt.wantLoc.String() {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got.Location(), tt.wantLoc)
}
})
}
}
func TestTimeFormat_Scanner(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(
`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.TimeFormat7TZ.Encode(reference))
if err != nil {
t.Fatal(err)
}
var got time.Time
err = db.QueryRow("SELECT * FROM test").Scan(sqlite3.TimeFormatAuto.Scanner(&got))
if err != nil {
t.Fatal(err)
}
if !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
@@ -167,3 +206,57 @@ func TestDB_timeCollation(t *testing.T) {
t.Fatal(err)
}
}
func TestDB_isoWeek(t *testing.T) {
t.Parallel()
tests := []time.Time{
time.Date(1977, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1977, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1977, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1978, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1978, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1978, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1979, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1979, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1979, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1980, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 28, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 29, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 30, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1981, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1981, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 3, 0, 0, 0, 0, time.UTC),
}
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT strftime('%G-W%V-%u', ?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
for _, tm := range tests {
stmt.BindTime(1, tm, sqlite3.TimeFormatDefault)
if stmt.Step() {
y, w := tm.ISOWeek()
d := tm.Weekday()
if d == 0 {
d = 7
}
want := fmt.Sprintf("%04d-W%02d-%d", y, w, d)
if got := stmt.ColumnText(0); got != want {
t.Errorf("got %q, want %q (%v)", got, want, tm)
}
}
stmt.Reset()
}
}

56
time.go
View File

@@ -164,9 +164,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)), nil
return time.Unix(int64(sec), int64(nsec)).UTC(), nil
case int64:
return time.Unix(v, 0), nil
return time.Unix(v, 0).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -181,9 +181,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(math.Floor(v))), nil
return time.UnixMilli(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMilli(int64(v)), nil
return time.UnixMilli(int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -198,9 +198,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(math.Floor(v))), nil
return time.UnixMicro(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMicro(int64(v)), nil
return time.UnixMicro(int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -215,9 +215,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(math.Floor(v))), nil
return time.Unix(0, int64(math.Floor(v))).UTC(), nil
case int64:
return time.Unix(0, int64(v)), nil
return time.Unix(0, int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -238,26 +238,16 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
dates := []TimeFormat{
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
TimeFormat1,
TimeFormat9, TimeFormat8,
TimeFormat6, TimeFormat5,
TimeFormat3, TimeFormat2, TimeFormat1,
}
for _, f := range dates {
t, err := time.Parse(string(f), s)
t, err := f.Decode(s)
if err == nil {
return t, nil
}
}
times := []TimeFormat{
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
}
for _, f := range times {
t, err := time.Parse(string(f), s)
if err == nil {
return t.AddDate(2000, 0, 0), nil
}
}
}
switch v := v.(type) {
case float64:
@@ -314,7 +304,10 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
if err != nil {
return time.Time{}, err
}
return t.AddDate(2000, 0, 0), nil
default:
s, ok := v.(string)
@@ -338,3 +331,20 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
}
return t, nil
}
// Scanner returns a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode a time value into dest using this format.
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
return timeScanner{dest, f}
}
type timeScanner struct {
*time.Time
TimeFormat
}
func (s timeScanner) Scan(src any) (err error) {
*s.Time, err = s.Decode(src)
return
}

33
tx.go
View File

@@ -7,6 +7,7 @@ import (
"math/rand"
"runtime"
"strconv"
"strings"
)
// Tx is an in-progress database transaction.
@@ -119,17 +120,8 @@ type Savepoint struct {
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
@@ -138,6 +130,27 @@ func (c *Conn) Savepoint() Savepoint {
return Savepoint{c: c, name: name}
}
func saveptName() (name string) {
defer func() {
if name == "" {
name = "sqlite3.Savepoint"
}
}()
var pc [8]uintptr
n := runtime.Callers(3, pc[:])
if n <= 0 {
return ""
}
frames := runtime.CallersFrames(pc[:n])
frame, more := frames.Next()
for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
frame, more = frames.Next()
}
return frame.Function
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//

View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -123,3 +125,31 @@ func (v Value) rawBytes(ptr uint32) []byte {
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}
// Pointer gets the pointer associated with this value,
// or nil if it has no associated pointer.
func (v Value) Pointer() any {
r := v.call(v.api.valuePointer, uint64(v.handle))
return util.GetHandle(v.ctx, uint32(r))
}
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
var data []byte
switch v.Type() {
case NULL:
data = append(data, "null"...)
case TEXT:
data = v.RawText()
case BLOB:
data = v.RawBlob()
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}

View File

@@ -2,8 +2,6 @@
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
It replaces the default VFS with a pure Go implementation,
that is tested on Linux, macOS and Windows,
but which should also work on illumos and the various BSDs.
It replaces the default SQLite VFS with a pure Go implementation.
It also exposes interfaces that should allow you to implement your own custom VFSes.

View File

@@ -47,7 +47,7 @@ type File interface {
// FileLockState extends File to implement the
// SQLITE_FCNTL_LOCKSTATE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntllockstate
type FileLockState interface {
File
LockState() LockLevel
@@ -56,7 +56,7 @@ type FileLockState interface {
// FileSizeHint extends File to implement the
// SQLITE_FCNTL_SIZE_HINT file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlsizehint
type FileSizeHint interface {
File
SizeHint(size int64) error
@@ -65,16 +65,25 @@ type FileSizeHint interface {
// FileHasMoved extends File to implement the
// SQLITE_FCNTL_HAS_MOVED file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlhasmoved
type FileHasMoved interface {
File
HasMoved() (bool, error)
}
// FileOverwrite extends File to implement the
// SQLITE_FCNTL_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntloverwrite
type FileOverwrite interface {
File
Overwrite() error
}
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlpowersafeoverwrite
type FilePowersafeOverwrite interface {
File
PowersafeOverwrite() bool
@@ -84,7 +93,7 @@ type FilePowersafeOverwrite interface {
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlcommitphasetwo
type FileCommitPhaseTwo interface {
File
CommitPhaseTwo() error
@@ -94,7 +103,7 @@ type FileCommitPhaseTwo interface {
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlbeginatomicwrite
type FileBatchAtomicWrite interface {
File
BeginAtomicWrite() error

View File

@@ -9,7 +9,6 @@ import (
"path/filepath"
"runtime"
"syscall"
"time"
)
type vfsOS struct{}
@@ -124,11 +123,10 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock LockLevel
psow bool
syncDir bool
readOnly bool
lock LockLevel
psow bool
syncDir bool
readOnly bool
}
var (

View File

@@ -1,11 +1,6 @@
package vfs
import (
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_PENDING_BYTE = 0x40000000
@@ -48,7 +43,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
if f.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetSharedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_SHARED
@@ -59,7 +54,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
if f.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetReservedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_RESERVED
@@ -77,7 +72,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
}
f.lock = LOCK_PENDING
}
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetExclusiveLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_EXCLUSIVE
@@ -133,18 +128,3 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
}
return osCheckReservedLock(f.File)
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -12,13 +11,6 @@ import (
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
@@ -211,9 +203,4 @@ func Test_vfsLock(t *testing.T) {
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -133,7 +133,7 @@ func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
n = copy((*m.data[base])[rest:], b)
if n < len(b) {
// Assume writes are page aligned.
return 0, io.ErrShortWrite
return n, io.ErrShortWrite
}
if size := off + int64(len(b)); size > m.size {
m.size = size
@@ -176,6 +176,8 @@ func (m *memFile) Size() (int64, error) {
return m.size, nil
}
const spinWait = 25 * time.Microsecond
func (m *memFile) Lock(lock vfs.LockLevel) error {
if m.lock >= lock {
return nil
@@ -210,8 +212,8 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
m.pending = m
}
for start := time.Now(); m.shared > 1; {
if time.Since(start) > time.Millisecond {
for before := time.Now(); m.shared > 1; {
if time.Since(before) > spinWait {
return sqlite3.BUSY
}
m.lockMtx.Unlock()

View File

@@ -1,4 +1,4 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
//go:build (freebsd || openbsd || netbsd || dragonfly || sqlite3_flock) && !sqlite3_nosys
package vfs
@@ -20,16 +20,16 @@ func osUnlock(file *os.File, start, len int64) _ErrorCode {
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
before := time.Now()
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)

View File

@@ -1,4 +1,4 @@
//go:build !sqlite3_bsd
//go:build !sqlite3_flock && !sqlite3_nosys
package vfs

View File

@@ -1,3 +1,5 @@
//go:build !sqlite3_nosys
package vfs
import (

33
vfs/os_nolock.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build !(linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos) || sqlite3_nosys
package vfs
import "os"
func osGetSharedLock(file *os.File) _ErrorCode {
return _IOERR_RDLOCK
}
func osGetReservedLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osGetPendingLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
return _IOERR_RDLOCK
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
return _IOERR_UNLOCK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
return false, _IOERR_CHECKRESERVEDLOCK
}

View File

@@ -1,4 +1,4 @@
//go:build linux || illumos
//go:build (linux || illumos) && !sqlite3_nosys
package vfs
@@ -27,16 +27,16 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
Start: start,
Len: len,
}
before := time.Now()
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)

36
vfs/os_std_access.go Normal file
View File

@@ -0,0 +1,36 @@
//go:build !unix || sqlite3_nosys
package vfs
import (
"io/fs"
"os"
)
const (
_S_IREAD = 0400
_S_IWRITE = 0200
_S_IEXEC = 0100
)
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = _S_IREAD
if flags == ACCESS_READWRITE {
want |= _S_IWRITE
}
if fi.IsDir() {
want |= _S_IEXEC
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && (!darwin || sqlite3_bsd)
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
package vfs
@@ -7,10 +7,6 @@ import (
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {

14
vfs/os_std_mode.go Normal file
View File

@@ -0,0 +1,14 @@
//go:build !unix || sqlite3_nosys
package vfs
import "os"
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}

12
vfs/os_std_open.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build !windows || sqlite3_nosys
package vfs
import (
"io/fs"
"os"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}

9
vfs/os_std_sync.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
package vfs
import "os"
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}

View File

@@ -1,20 +1,14 @@
//go:build unix
//go:build unix && !sqlite3_nosys
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
@@ -37,64 +31,3 @@ func osSetMode(file *os.File, modeof string) error {
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

82
vfs/os_unix2.go Normal file
View File

@@ -0,0 +1,82 @@
//go:build (linux || darwin || freebsd || openbsd || netbsd || dragonfly || illumos) && !sqlite3_nosys
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osGetSharedLock(file *os.File) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

View File

@@ -1,3 +1,5 @@
//go:build !sqlite3_nosys
package vfs
import (
@@ -25,40 +27,9 @@ func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
func osGetSharedLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
rc := osReadLock(file, _PENDING_BYTE, 1, 0)
if rc == _OK {
// Acquire the SHARED lock.
@@ -70,16 +41,22 @@ func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
if rc != _OK {
// Reacquire the SHARED lock.
@@ -125,6 +102,11 @@ func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
return _OK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
@@ -138,6 +120,7 @@ func osUnlock(file *os.File, start, len uint32) _ErrorCode {
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
before := time.Now()
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
@@ -145,11 +128,16 @@ func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
if err := windows.TimeBeginPeriod(1); err != nil {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
if err := windows.TimeEndPeriod(1); err != nil {
break
}
}
return osLockErrorCode(err, def)
}

View File

@@ -2,6 +2,7 @@ package mptest
import (
"bytes"
"compress/bzip2"
"context"
"crypto/rand"
"embed"
@@ -24,8 +25,8 @@ import (
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
)
//go:embed testdata/mptest.wasm
var binary []byte
//go:embed testdata/mptest.wasm.bz2
var compressed string
//go:embed testdata/*.*test
var scripts embed.FS
@@ -48,6 +49,11 @@ func TestMain(m *testing.M) {
panic(err)
}
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
@@ -151,7 +157,7 @@ func Test_multiwrite01(t *testing.T) {
func Test_config01_memory(t *testing.T) {
ctx := util.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "test.db",
cfg := config(ctx).WithArgs("mptest", "/test.db",
"config01.test",
"--vfs", "memdb",
"--timeout", "1000")

View File

@@ -1,2 +1,2 @@
mptest.wasm filter=lfs diff=lfs merge=lfs -text
mptest.wasm.bz2 filter=lfs diff=lfs merge=lfs -text
*.*test -crlf

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
@@ -28,4 +28,5 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv mptest.tmp mptest.wasm
mv mptest.tmp mptest.wasm
bzip2 -9f mptest.wasm

View File

@@ -1,5 +1,4 @@
#include <stdbool.h>
#include <stddef.h>
#include <unistd.h>
// Amalgamation
#include "sqlite3.c"
@@ -8,9 +7,8 @@
__attribute__((constructor)) void init() { sqlite3_initialize(); }
static int dont_unlink(const char *pathname) { return 0; }
#define sqlite3_enable_load_extension(...)
#define sqlite3_trace(...)
#define unlink dont_unlink
#define unlink(...) (0)
#undef UNUSED_PARAMETER
#include "mptest.c"

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d23b37d507077cdcbb616852185370a227278b599187dc134200ed274a7a3a02
size 1441194

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74294bf19d213056ef5ffb7a980c3a7de5d029d0621ded53394d3055dfc4f604
size 513604

View File

@@ -2,6 +2,7 @@ package speedtest1
import (
"bytes"
"compress/bzip2"
"context"
"crypto/rand"
"flag"
@@ -23,8 +24,8 @@ import (
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/speedtest1.wasm
var binary []byte
//go:embed testdata/speedtest1.wasm.bz2
var compressed string
var (
rt wazero.Runtime
@@ -45,6 +46,11 @@ func TestMain(m *testing.M) {
panic(err)
}
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)

View File

@@ -1 +1 @@
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text
speedtest1.wasm.bz2 filter=lfs diff=lfs merge=lfs -text

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
@@ -23,4 +23,5 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv speedtest1.tmp speedtest1.wasm
mv speedtest1.tmp speedtest1.wasm
bzip2 -9f speedtest1.wasm

View File

@@ -6,5 +6,5 @@
// VFS
#include "vfs.c"
#define randomFunc(args...) randomFunc2(args)
#define randomFunc randomFunc2
#include "speedtest1.c"

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83d67feda51cc974634e245ac2b072f9587c607c7ad97321f2de9dde2188e63a
size 1481348

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:966754393264cc43eb931ece22941d0d607e7e776e26c26b548209d2264d01a1
size 527530

Some files were not shown because too many files have changed in this diff Show More