mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 22:19:14 +00:00
Compare commits
32 Commits
gormlite/v
...
v0.10.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
828788912e | ||
|
|
6f8645cd2e | ||
|
|
c00927e8bb | ||
|
|
6b28be6d0e | ||
|
|
310b4ff29d | ||
|
|
e82cf16b11 | ||
|
|
24c9b57c56 | ||
|
|
24b965ac7e | ||
|
|
446168c572 | ||
|
|
a9e2cbbfc5 | ||
|
|
a7c00eb150 | ||
|
|
0bcdb712ba | ||
|
|
2157d0f325 | ||
|
|
6353160619 | ||
|
|
501d157279 | ||
|
|
4db18a7b9a | ||
|
|
a9dddaa86c | ||
|
|
b25936dbec | ||
|
|
bf23041e46 | ||
|
|
d60fceac92 | ||
|
|
61da30f44a | ||
|
|
d4ff605983 | ||
|
|
8d0c654178 | ||
|
|
728e59951b | ||
|
|
f7b16bad5c | ||
|
|
db3e6da31a | ||
|
|
3f443b2ecc | ||
|
|
eec45ea684 | ||
|
|
f6d77f3cf4 | ||
|
|
d5d7cd1f2d | ||
|
|
a33a187d48 | ||
|
|
70c6ee15c6 |
29
.github/workflows/bsd.yml
vendored
Normal file
29
.github/workflows/bsd.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: BSD
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: macos-12
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
- name: Set up
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: stable
|
||||
|
||||
- name: Build
|
||||
run: GOOS=freebsd go test -c ./...
|
||||
|
||||
- name: Test
|
||||
uses: cross-platform-actions/action@v0.21.1
|
||||
with:
|
||||
operating_system: freebsd
|
||||
version: '13.2'
|
||||
sync_files: runner-to-vm
|
||||
run: find . -name '*.test' -maxdepth 1 -exec {} -test.v \;
|
||||
76
.github/workflows/codeql.yml
vendored
76
.github/workflows/codeql.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
#
|
||||
# ******** NOTE ********
|
||||
# We have attempted to detect the languages in your repository. Please check
|
||||
# the `language` matrix defined below to confirm you have the correct set of
|
||||
# supported CodeQL languages.
|
||||
#
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
# The branches below must be a subset of the branches above
|
||||
branches: [ "main" ]
|
||||
schedule:
|
||||
- cron: '15 18 * * 6'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
language: [ 'go' ]
|
||||
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
|
||||
# Use only 'java' to analyze code written in Java, Kotlin or both
|
||||
# Use only 'javascript' to analyze code written in JavaScript, TypeScript or both
|
||||
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v2
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
|
||||
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||
# queries: security-extended,security-and-quality
|
||||
|
||||
|
||||
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
|
||||
|
||||
# If the Autobuild fails above, remove it and uncomment the following three lines.
|
||||
# modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
|
||||
|
||||
# - run: |
|
||||
# echo "Run, Build Application using script"
|
||||
# ./location_of_script_within_repo/buildscript.sh
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
22
.github/workflows/cross.sh
vendored
Executable file
22
.github/workflows/cross.sh
vendored
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
echo android ; GOOS=android GOARCH=amd64 go build .
|
||||
echo darwin ; GOOS=darwin GOARCH=amd64 go build .
|
||||
echo dragonfly ; GOOS=dragonfly GOARCH=amd64 go build .
|
||||
echo freebsd ; GOOS=freebsd GOARCH=amd64 go build .
|
||||
echo illumos ; GOOS=illumos GOARCH=amd64 go build .
|
||||
echo ios ; GOOS=ios GOARCH=amd64 go build .
|
||||
echo linux ; GOOS=linux GOARCH=amd64 go build .
|
||||
echo netbsd ; GOOS=netbsd GOARCH=amd64 go build .
|
||||
echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
|
||||
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
|
||||
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
|
||||
echo windows ; GOOS=windows GOARCH=amd64 go build .
|
||||
# echo aix ; GOOS=aix GOARCH=ppc64 go build .
|
||||
echo js ; GOOS=js GOARCH=wasm go build .
|
||||
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
|
||||
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
|
||||
echo darwin-nosys ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
echo linux-nosys ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
echo windows-nosys ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .
|
||||
21
.github/workflows/cross.yml
vendored
Normal file
21
.github/workflows/cross.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Cross compile
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
- name: Set up
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: stable
|
||||
|
||||
- name: Build
|
||||
run: .github/workflows/cross.sh
|
||||
13
.github/workflows/go.yml
vendored
13
.github/workflows/go.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: 'true'
|
||||
|
||||
@@ -39,7 +39,6 @@ jobs:
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
continue-on-error: true
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
@@ -48,8 +47,7 @@ jobs:
|
||||
run: go test -v ./...
|
||||
|
||||
- name: Test no locks
|
||||
run: go test -v -tags sqlite3_nolock .
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: go test -v -tags sqlite3_nosys ./tests -run TestDB_nolock
|
||||
|
||||
- name: Test BSD locks
|
||||
run: go test -v -tags sqlite3_flock ./...
|
||||
@@ -58,10 +56,9 @@ jobs:
|
||||
- name: Coverage report
|
||||
uses: ncruces/go-coverage-report@v0
|
||||
with:
|
||||
chart: 'true'
|
||||
amend: 'true'
|
||||
reuse-go: 'true'
|
||||
chart: true
|
||||
amend: true
|
||||
reuse-go: true
|
||||
if: |
|
||||
github.event_name == 'push' &&
|
||||
matrix.os == 'ubuntu-latest'
|
||||
continue-on-error: true
|
||||
|
||||
44
README.md
44
README.md
@@ -15,16 +15,18 @@ and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/blob`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blob)
|
||||
simplifies incremental BLOB I/O.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
|
||||
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
|
||||
registers Unicode aware functions.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
|
||||
implements an in-memory VFS.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
|
||||
implements a VFS for immutable databases.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
|
||||
registers Unicode aware functions.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
|
||||
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
|
||||
@@ -43,39 +45,39 @@ To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
|
||||
to always use `EXCLUSIVE` locking mode for WAL databases.
|
||||
|
||||
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
|
||||
to use the [`database/sql`](https://pkg.go.dev/database/sql)
|
||||
driver with WAL mode databases you should disable connection pooling by calling
|
||||
to use the [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
with WAL mode databases you should disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### POSIX Advisory Locks
|
||||
#### File Locking
|
||||
|
||||
POSIX advisory locks, which SQLite uses, are
|
||||
POSIX advisory locks, which SQLite uses on Unix, are
|
||||
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
|
||||
|
||||
On Linux, macOS and illumos, this module uses
|
||||
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
|
||||
to synchronize access to database files.
|
||||
OFD locks are fully compatible with process-associated POSIX advisory locks.
|
||||
OFD locks are fully compatible with POSIX advisory locks.
|
||||
|
||||
On BSD Unixes, this module may use
|
||||
On BSD Unixes, this module uses
|
||||
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
|
||||
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
|
||||
On BSD Unixes, BSD locks are fully compatible with POSIX advisory locks.
|
||||
|
||||
##### TL;DR
|
||||
On Windows, this module uses `LockFile`, `LockFileEx`, and `UnlockFile`,
|
||||
like SQLite.
|
||||
|
||||
In all platforms for which this package builds,
|
||||
it should be safe to use it to access databases concurrently,
|
||||
from multiple goroutines, processes, and
|
||||
with _other_ implementations of SQLite.
|
||||
|
||||
If the package does not build for your platform,
|
||||
see [this](vfs/README.md#portability).
|
||||
On all other platforms, file locking is not supported, and you must use
|
||||
[`nolock=1`](https://www.sqlite.org/uri.html#urinolock)
|
||||
to open database files.
|
||||
To use the [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
with `nolock=1` you must disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### Testing
|
||||
|
||||
The pure Go VFS is tested by running SQLite's
|
||||
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
|
||||
on Linux, macOS and Windows.
|
||||
on Linux, macOS, Windows and FreeBSD.
|
||||
Performance is tested by running
|
||||
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
|
||||
|
||||
@@ -86,6 +88,8 @@ Performance is tested by running
|
||||
- [x] nested transactions
|
||||
- [x] incremental BLOB I/O
|
||||
- [x] online backup
|
||||
- [x] JSON support
|
||||
- [ ] virtual tables
|
||||
- [ ] session extension
|
||||
- [ ] custom VFSes
|
||||
- [x] custom VFS API
|
||||
|
||||
7
blob.go
7
blob.go
@@ -118,8 +118,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
want := int64(1024 * 1024)
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(65536)
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
@@ -175,8 +175,11 @@ func (b *Blob) Write(p []byte) (n int, err error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_write.html
|
||||
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
want := int64(1024 * 1024)
|
||||
avail := b.bytes - b.offset
|
||||
want := int64(65536)
|
||||
if l, ok := r.(*io.LimitedReader); ok && want > l.N {
|
||||
want = l.N
|
||||
}
|
||||
if want > avail {
|
||||
want = avail
|
||||
}
|
||||
|
||||
78
conn.go
78
conn.go
@@ -7,10 +7,9 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// Conn is a database connection handle.
|
||||
@@ -21,7 +20,6 @@ type Conn struct {
|
||||
*sqlite
|
||||
|
||||
interrupt context.Context
|
||||
waiter chan struct{}
|
||||
pending *Stmt
|
||||
arena arena
|
||||
|
||||
@@ -48,6 +46,8 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
return newConn(filename, flags)
|
||||
}
|
||||
|
||||
type connKey struct{}
|
||||
|
||||
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
@@ -63,6 +63,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
|
||||
c := &Conn{sqlite: sqlite}
|
||||
c.arena = c.newArena(1024)
|
||||
c.ctx = context.WithValue(c.ctx, connKey{}, c)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -131,7 +132,6 @@ func (c *Conn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.SetInterrupt(context.Background())
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
|
||||
@@ -244,65 +244,40 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
<-c.waiter // Wait for it to finish.
|
||||
c.waiter = nil
|
||||
}
|
||||
// Reset the pending statement.
|
||||
if c.pending != nil {
|
||||
// An uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
} else {
|
||||
c.pending.Reset()
|
||||
}
|
||||
|
||||
old = c.interrupt
|
||||
c.interrupt = ctx
|
||||
// Remove the handler if the context can't be canceled.
|
||||
if ctx == nil || ctx.Done() == nil {
|
||||
c.call(c.api.progressHandler, uint64(c.handle), 0)
|
||||
return old
|
||||
}
|
||||
|
||||
// Creating an uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
}
|
||||
c.pending.Step()
|
||||
|
||||
// Don't create the goroutine if we're already interrupted.
|
||||
// This happens frequently while restoring to a previously interrupted state.
|
||||
if c.checkInterrupt() {
|
||||
return old
|
||||
}
|
||||
|
||||
waiter := make(chan struct{})
|
||||
c.waiter = waiter
|
||||
go func() {
|
||||
select {
|
||||
case <-waiter: // Waiter was cancelled.
|
||||
break
|
||||
|
||||
case <-ctx.Done(): // Done was closed.
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
}
|
||||
|
||||
// Signal that the waiter has finished.
|
||||
waiter <- struct{}{}
|
||||
}()
|
||||
c.call(c.api.progressHandler, uint64(c.handle), 100)
|
||||
return old
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() bool {
|
||||
if c.interrupt == nil || c.interrupt.Err() == nil {
|
||||
return false
|
||||
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
|
||||
if c.interrupt != nil && c.interrupt.Err() != nil {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() {
|
||||
if c.interrupt != nil && c.interrupt.Err() != nil {
|
||||
c.call(c.api.interrupt, uint64(c.handle))
|
||||
}
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pragma executes a PRAGMA statement and returns any results.
|
||||
@@ -328,12 +303,9 @@ func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
|
||||
// DriverConn is implemented by the SQLite [database/sql] driver connection.
|
||||
//
|
||||
// It can be used to access advanced SQLite features like
|
||||
// [savepoints], [online backup] and [incremental BLOB I/O].
|
||||
// It can be used to access SQLite features like [online backup].
|
||||
//
|
||||
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
|
||||
// [online backup]: https://www.sqlite.org/backup.html
|
||||
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
|
||||
type DriverConn interface {
|
||||
Raw() *Conn
|
||||
}
|
||||
|
||||
1
const.go
1
const.go
@@ -97,6 +97,7 @@ const (
|
||||
IOERR_ROLLBACK_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (31 << 8)
|
||||
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
|
||||
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
|
||||
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
|
||||
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
|
||||
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
|
||||
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)
|
||||
|
||||
153
context.go
153
context.go
@@ -1,6 +1,7 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
@@ -13,24 +14,33 @@ import (
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
*sqlite
|
||||
c *Conn
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Conn returns the database connection of the
|
||||
// [Conn.CreateFunction] or [Conn.CreateWindowFunction]
|
||||
// routines that originally registered the application defined function.
|
||||
//
|
||||
// https://sqlite.org/c3ref/context_db_handle.html
|
||||
func (ctx Context) Conn() *Conn {
|
||||
return ctx.c
|
||||
}
|
||||
|
||||
// SetAuxData saves metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(c.ctx, data)
|
||||
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
|
||||
func (ctx Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(ctx.c.ctx, data)
|
||||
ctx.c.call(ctx.c.api.setAuxData, uint64(ctx.handle), uint64(n), uint64(ptr))
|
||||
}
|
||||
|
||||
// GetAuxData returns metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) GetAuxData(n int) any {
|
||||
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
|
||||
return util.GetHandle(c.ctx, ptr)
|
||||
func (ctx Context) GetAuxData(n int) any {
|
||||
ptr := uint32(ctx.c.call(ctx.c.api.getAuxData, uint64(ctx.handle), uint64(n)))
|
||||
return util.GetHandle(ctx.c.ctx, ptr)
|
||||
}
|
||||
|
||||
// ResultBool sets the result of the function to a bool.
|
||||
@@ -38,125 +48,160 @@ func (c Context) GetAuxData(n int) any {
|
||||
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBool(value bool) {
|
||||
func (ctx Context) ResultBool(value bool) {
|
||||
var i int64
|
||||
if value {
|
||||
i = 1
|
||||
}
|
||||
c.ResultInt64(i)
|
||||
ctx.ResultInt64(i)
|
||||
}
|
||||
|
||||
// ResultInt sets the result of the function to an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt(value int) {
|
||||
c.ResultInt64(int64(value))
|
||||
func (ctx Context) ResultInt(value int) {
|
||||
ctx.ResultInt64(int64(value))
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of the function to an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt64(value int64) {
|
||||
c.call(c.api.resultInteger,
|
||||
uint64(c.handle), uint64(value))
|
||||
func (ctx Context) ResultInt64(value int64) {
|
||||
ctx.c.call(ctx.c.api.resultInteger,
|
||||
uint64(ctx.handle), uint64(value))
|
||||
}
|
||||
|
||||
// ResultFloat sets the result of the function to a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultFloat(value float64) {
|
||||
c.call(c.api.resultFloat,
|
||||
uint64(c.handle), math.Float64bits(value))
|
||||
func (ctx Context) ResultFloat(value float64) {
|
||||
ctx.c.call(ctx.c.api.resultFloat,
|
||||
uint64(ctx.handle), math.Float64bits(value))
|
||||
}
|
||||
|
||||
// ResultText sets the result of the function to a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultText(value string) {
|
||||
ptr := c.newString(value)
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
func (ctx Context) ResultText(value string) {
|
||||
ptr := ctx.c.newString(value)
|
||||
ctx.c.call(ctx.c.api.resultText,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(ctx.c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBlob(value []byte) {
|
||||
ptr := c.newBytes(value)
|
||||
c.call(c.api.resultBlob,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor))
|
||||
func (ctx Context) ResultBlob(value []byte) {
|
||||
ptr := ctx.c.newBytes(value)
|
||||
ctx.c.call(ctx.c.api.resultBlob,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(ctx.c.api.destructor))
|
||||
}
|
||||
|
||||
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultZeroBlob(n int64) {
|
||||
c.call(c.api.resultZeroBlob,
|
||||
uint64(c.handle), uint64(n))
|
||||
func (ctx Context) ResultZeroBlob(n int64) {
|
||||
ctx.c.call(ctx.c.api.resultZeroBlob,
|
||||
uint64(ctx.handle), uint64(n))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of the function to NULL.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultNull() {
|
||||
c.call(c.api.resultNull,
|
||||
uint64(c.handle))
|
||||
func (ctx Context) ResultNull() {
|
||||
ctx.c.call(ctx.c.api.resultNull,
|
||||
uint64(ctx.handle))
|
||||
}
|
||||
|
||||
// ResultTime sets the result of the function to a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
if format == TimeFormatDefault {
|
||||
c.resultRFC3339Nano(value)
|
||||
ctx.resultRFC3339Nano(value)
|
||||
return
|
||||
}
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
c.ResultText(v)
|
||||
ctx.ResultText(v)
|
||||
case int64:
|
||||
c.ResultInt64(v)
|
||||
ctx.ResultInt64(v)
|
||||
case float64:
|
||||
c.ResultFloat(v)
|
||||
ctx.ResultFloat(v)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (c Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
func (ctx Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano)) + 5
|
||||
|
||||
ptr := c.new(maxlen)
|
||||
buf := util.View(c.mod, ptr, maxlen)
|
||||
ptr := ctx.c.new(maxlen)
|
||||
buf := util.View(ctx.c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
ctx.c.call(ctx.c.api.resultText,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(ctx.c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
|
||||
// except that it also associates ptr with that NULL value such that it can be retrieved
|
||||
// within an application-defined SQL function using [Value.Pointer].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultPointer(ptr any) {
|
||||
valPtr := util.AddHandle(ctx.c.ctx, ptr)
|
||||
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
|
||||
}
|
||||
|
||||
// ResultJSON sets the result of the function to the JSON encoding of value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultJSON(value any) {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
}
|
||||
ptr := ctx.c.newBytes(data)
|
||||
ctx.c.call(ctx.c.api.resultText,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(data)),
|
||||
uint64(ctx.c.api.destructor))
|
||||
}
|
||||
|
||||
// ResultValue sets the result of the function a copy of [Value].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultValue(value Value) {
|
||||
if value.sqlite != ctx.c.sqlite {
|
||||
ctx.ResultError(MISUSE)
|
||||
}
|
||||
ctx.c.call(ctx.c.api.resultValue,
|
||||
uint64(ctx.handle), uint64(value.handle))
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultError(err error) {
|
||||
func (ctx Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
c.call(c.api.resultErrorMem, uint64(c.handle))
|
||||
ctx.c.call(ctx.c.api.resultErrorMem, uint64(ctx.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
c.call(c.api.resultErrorBig, uint64(c.handle))
|
||||
ctx.c.call(ctx.c.api.resultErrorBig, uint64(ctx.handle))
|
||||
return
|
||||
}
|
||||
|
||||
str := err.Error()
|
||||
ptr := c.newString(str)
|
||||
c.call(c.api.resultError,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(str)))
|
||||
c.free(ptr)
|
||||
ptr := ctx.c.newString(str)
|
||||
ctx.c.call(ctx.c.api.resultError,
|
||||
uint64(ctx.handle), uint64(ptr), uint64(len(str)))
|
||||
ctx.c.free(ptr)
|
||||
|
||||
var code uint64
|
||||
var ecode ErrorCode
|
||||
@@ -168,7 +213,7 @@ func (c Context) ResultError(err error) {
|
||||
code = uint64(ecode)
|
||||
}
|
||||
if code != 0 {
|
||||
c.call(c.api.resultErrorCode,
|
||||
uint64(c.handle), code)
|
||||
ctx.c.call(ctx.c.api.resultErrorCode,
|
||||
uint64(ctx.handle), code)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
@@ -55,8 +57,8 @@ func init() {
|
||||
//
|
||||
// The init function is called by the driver on new connections.
|
||||
// The conn can be used to execute queries, register functions, etc.
|
||||
// Any error return closes the conn and passes the error to database/sql.
|
||||
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
|
||||
// Any error return closes the conn and passes the error to [database/sql].
|
||||
func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
|
||||
c, err := newConnector(dataSourceName, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -78,7 +80,7 @@ func (sqlite) OpenConnector(name string) (driver.Connector, error) {
|
||||
return newConnector(name, nil)
|
||||
}
|
||||
|
||||
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
|
||||
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
|
||||
c := connector{name: name, init: init}
|
||||
if strings.HasPrefix(name, "file:") {
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
@@ -94,7 +96,7 @@ func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
type connector struct {
|
||||
init func(ctx context.Context, conn *sqlite3.Conn) error
|
||||
init func(*sqlite3.Conn) error
|
||||
name string
|
||||
txlock string
|
||||
pragmas bool
|
||||
@@ -132,27 +134,24 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.reusable = true
|
||||
} else {
|
||||
s, _, err := c.Conn.Prepare(`
|
||||
SELECT * FROM
|
||||
PRAGMA_locking_mode,
|
||||
PRAGMA_query_only;
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.Step() {
|
||||
c.reusable = s.ColumnText(0) == "normal"
|
||||
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
|
||||
}
|
||||
err = s.Close()
|
||||
}
|
||||
if n.init != nil {
|
||||
err = n.init(c.Conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if n.init != nil {
|
||||
err = n.init(ctx, c.Conn)
|
||||
if n.pragmas || n.init != nil {
|
||||
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.Step() && s.ColumnBool(0) {
|
||||
c.readOnly = '1'
|
||||
} else {
|
||||
c.readOnly = '0'
|
||||
}
|
||||
err = s.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,7 +164,6 @@ type conn struct {
|
||||
txBegin string
|
||||
txCommit string
|
||||
txRollback string
|
||||
reusable bool
|
||||
readOnly byte
|
||||
}
|
||||
|
||||
@@ -174,7 +172,6 @@ var (
|
||||
_ driver.ConnPrepareContext = &conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
@@ -182,10 +179,6 @@ func (c *conn) Raw() *sqlite3.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
}
|
||||
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
@@ -199,10 +192,10 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
txBegin = `
|
||||
BEGIN deferred;
|
||||
PRAGMA query_only=on`
|
||||
c.txCommit = `
|
||||
c.txRollback = `
|
||||
ROLLBACK;
|
||||
PRAGMA query_only=` + string(c.readOnly)
|
||||
c.txRollback = c.txCommit
|
||||
c.txCommit = c.txRollback
|
||||
}
|
||||
|
||||
switch opts.Isolation {
|
||||
@@ -233,7 +226,13 @@ func (c *conn) Commit() error {
|
||||
}
|
||||
|
||||
func (c *conn) Rollback() error {
|
||||
return c.Conn.Exec(c.txRollback)
|
||||
err := c.Conn.Exec(c.txRollback)
|
||||
if errors.Is(err, sqlite3.INTERRUPT) {
|
||||
old := c.Conn.SetInterrupt(context.Background())
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
err = c.Conn.Exec(c.txRollback)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
@@ -270,6 +269,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
if savept, ok := ctx.(*saveptCtx); ok {
|
||||
// Called from driver.Savepoint.
|
||||
savept.Savepoint = c.Savepoint()
|
||||
return resultRowsAffected(0), nil
|
||||
}
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
@@ -281,6 +286,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
|
||||
return newResult(c.Conn), nil
|
||||
}
|
||||
|
||||
func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
@@ -377,8 +386,12 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
err = s.Stmt.BindBlob(id, a)
|
||||
case sqlite3.ZeroBlob:
|
||||
err = s.Stmt.BindZeroBlob(id, int64(a))
|
||||
case interface{ Value() any }:
|
||||
err = s.Stmt.BindPointer(id, a.Value())
|
||||
case time.Time:
|
||||
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
case json.Marshaler:
|
||||
err = s.Stmt.BindJSON(id, a)
|
||||
case nil:
|
||||
err = s.Stmt.BindNull(id)
|
||||
default:
|
||||
@@ -395,7 +408,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, nil:
|
||||
sqlite3.ZeroBlob, interface{ Value() any },
|
||||
time.Time, json.Marshaler, nil:
|
||||
return nil
|
||||
default:
|
||||
return driver.ErrSkip
|
||||
@@ -474,11 +488,7 @@ func (r *rows) Next(dest []driver.Value) error {
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
} else {
|
||||
dest[i] = nil
|
||||
}
|
||||
dest[i] = nil
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
65
driver/json_test.go
Normal file
65
driver/json_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package driver_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func Example_json() {
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE orders (
|
||||
cart_id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
cart TEXT
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
type CartItem struct {
|
||||
ItemID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Quantity int `json:"quantity,omitempty"`
|
||||
Price int `json:"price,omitempty"`
|
||||
}
|
||||
|
||||
type Cart struct {
|
||||
Items []CartItem `json:"items"`
|
||||
}
|
||||
|
||||
_, err = db.Exec(`INSERT INTO orders (user_id, cart) VALUES (?, ?)`, 123, sqlite3.JSON(Cart{
|
||||
[]CartItem{
|
||||
{ItemID: "111", Name: "T-shirt", Quantity: 1, Price: 250},
|
||||
{ItemID: "222", Name: "Trousers", Quantity: 1, Price: 600},
|
||||
},
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var total string
|
||||
err = db.QueryRow(`
|
||||
SELECT total(json_each.value -> 'price')
|
||||
FROM orders, json_each(cart -> 'items')
|
||||
WHERE cart_id = last_insert_rowid()
|
||||
`).Scan(&total)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("total:", total)
|
||||
// Output:
|
||||
// total: 850
|
||||
}
|
||||
27
driver/savepoint.go
Normal file
27
driver/savepoint.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Savepoint establishes a new transaction savepoint.
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func Savepoint(tx *sql.Tx) sqlite3.Savepoint {
|
||||
var ctx saveptCtx
|
||||
tx.ExecContext(&ctx, "")
|
||||
return ctx.Savepoint
|
||||
}
|
||||
|
||||
type saveptCtx struct{ sqlite3.Savepoint }
|
||||
|
||||
func (*saveptCtx) Deadline() (deadline time.Time, ok bool) { return }
|
||||
|
||||
func (*saveptCtx) Done() <-chan struct{} { return nil }
|
||||
|
||||
func (*saveptCtx) Err() error { return nil }
|
||||
|
||||
func (*saveptCtx) Value(key any) any { return nil }
|
||||
87
driver/savepoint_test.go
Normal file
87
driver/savepoint_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package driver_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func ExampleSavepoint() {
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = func() error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.Prepare(`INSERT INTO users (id, name) VALUES (?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(0, "go")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(1, "zig")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
savept := driver.Savepoint(tx)
|
||||
|
||||
_, err = stmt.Exec(2, "whatever")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = savept.Rollback()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(3, "rust")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s %s\n", id, name)
|
||||
}
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 3 rust
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
func ExampleDriverConn() {
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", "demo.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove("demo.db")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(sqlite3.DriverConn).Raw()
|
||||
savept := conn.Savepoint()
|
||||
defer savept.Release(&err)
|
||||
|
||||
blob, err := conn.OpenBlob("main", "test", "col", id, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
_, err = fmt.Fprint(blob, "Hello BLOB!")
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var msg string
|
||||
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(msg)
|
||||
// Output:
|
||||
// Hello BLOB!
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# Embeddable WASM build of SQLite
|
||||
|
||||
This folder includes an embeddable WASM build of SQLite 3.43.1 for use with
|
||||
This folder includes an embeddable WASM build of SQLite 3.44.0 for use with
|
||||
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
|
||||
|
||||
The following optional features are compiled in:
|
||||
|
||||
@@ -13,6 +13,8 @@ sqlite3_finalize
|
||||
sqlite3_reset
|
||||
sqlite3_step
|
||||
sqlite3_exec
|
||||
sqlite3_interrupt
|
||||
sqlite3_progress_handler_go
|
||||
sqlite3_clear_bindings
|
||||
sqlite3_bind_parameter_count
|
||||
sqlite3_bind_parameter_index
|
||||
@@ -23,6 +25,7 @@ sqlite3_bind_double
|
||||
sqlite3_bind_text64
|
||||
sqlite3_bind_blob64
|
||||
sqlite3_bind_zeroblob64
|
||||
sqlite3_bind_pointer_go
|
||||
sqlite3_column_count
|
||||
sqlite3_column_name
|
||||
sqlite3_column_type
|
||||
@@ -62,12 +65,15 @@ sqlite3_value_double
|
||||
sqlite3_value_text
|
||||
sqlite3_value_blob
|
||||
sqlite3_value_bytes
|
||||
sqlite3_value_pointer_go
|
||||
sqlite3_result_null
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_double
|
||||
sqlite3_result_text64
|
||||
sqlite3_result_blob64
|
||||
sqlite3_result_zeroblob64
|
||||
sqlite3_result_pointer_go
|
||||
sqlite3_result_value
|
||||
sqlite3_result_error
|
||||
sqlite3_result_error_code
|
||||
sqlite3_result_error_nomem
|
||||
|
||||
Binary file not shown.
59
ext/blob/blob.go
Normal file
59
ext/blob/blob.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package blob provides an alternative interface to incremental BLOB I/O.
|
||||
package blob
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register registers the blob_open SQL function.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
db.CreateFunction("blob_open", -1,
|
||||
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
|
||||
}
|
||||
|
||||
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) < 6 {
|
||||
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
|
||||
return
|
||||
}
|
||||
|
||||
row := arg[3].Int64()
|
||||
|
||||
var err error
|
||||
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
|
||||
if ok {
|
||||
err = blob.Reopen(row)
|
||||
if errors.Is(err, sqlite3.MISUSE) {
|
||||
// Blob was closed (db, table or column changed).
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
db := arg[0].Text()
|
||||
table := arg[1].Text()
|
||||
column := arg[2].Text()
|
||||
write := arg[4].Bool()
|
||||
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
|
||||
}
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
|
||||
fn := arg[5].Pointer().(OpenCallback)
|
||||
err = fn(blob, arg[6:]...)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// This ensures the blob is closed if db, table or column change.
|
||||
ctx.SetAuxData(0, blob)
|
||||
ctx.SetAuxData(1, blob)
|
||||
ctx.SetAuxData(2, blob)
|
||||
}
|
||||
|
||||
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
|
||||
61
ext/blob/blob_test.go
Normal file
61
ext/blob/blob_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package blob_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/blob"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
// Open the database, registering the extension.
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
|
||||
blob.Register(conn)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
const message = "Hello BLOB!"
|
||||
|
||||
// Create the BLOB.
|
||||
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Write the BLOB.
|
||||
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
|
||||
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
|
||||
_, err = io.WriteString(blob, message)
|
||||
return err
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the BLOB.
|
||||
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
|
||||
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
|
||||
_, err = io.Copy(os.Stdout, blob)
|
||||
return err
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// Hello BLOB!
|
||||
}
|
||||
71
func.go
71
func.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
@@ -47,6 +46,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
// If fn returns an [io.Closer], it will be called to free resources.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
@@ -70,7 +70,7 @@ type AggregateFunction interface {
|
||||
// The function arguments, if any, corresponding to the row being added are passed to Step.
|
||||
Step(ctx Context, arg ...Value)
|
||||
|
||||
// Value is invoked to return the current value of the aggregate.
|
||||
// Value is invoked to return the current (or final) value of the aggregate.
|
||||
Value(ctx Context)
|
||||
}
|
||||
|
||||
@@ -85,17 +85,6 @@ type WindowFunction interface {
|
||||
Inverse(ctx Context, arg ...Value)
|
||||
}
|
||||
|
||||
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
@@ -106,57 +95,57 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
|
||||
}
|
||||
|
||||
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackHandle(db, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
var handle uint32
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{db, pCtx})
|
||||
if err := util.DelHandle(ctx, handle); err != nil {
|
||||
Context{sqlite, pCtx}.ResultError(err)
|
||||
Context{db, pCtx}.ResultError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{db, pCtx})
|
||||
}
|
||||
|
||||
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
db := ctx.Value(connKey{}).(*Conn)
|
||||
fn := callbackAggregate(db, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
|
||||
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
|
||||
return util.GetHandle(sqlite.ctx, pApp)
|
||||
func callbackHandle(db *Conn, pCtx uint32) any {
|
||||
pApp := uint32(db.call(db.api.userData, uint64(pCtx)))
|
||||
return util.GetHandle(db.ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
|
||||
func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any {
|
||||
// On close, we're getting rid of the handle.
|
||||
// Don't allocate space to store it.
|
||||
var size uint64
|
||||
if close == nil {
|
||||
size = ptrlen
|
||||
}
|
||||
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
|
||||
ptr := uint32(db.call(db.api.aggregateCtx, uint64(pCtx), size))
|
||||
|
||||
// Try loading the handle, if we already have one, or want a new one.
|
||||
if ptr != 0 || size != 0 {
|
||||
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
|
||||
fn := util.GetHandle(sqlite.ctx, handle)
|
||||
if handle := util.ReadUint32(db.mod, ptr); handle != 0 {
|
||||
fn := util.GetHandle(db.ctx, handle)
|
||||
if close != nil {
|
||||
*close = handle
|
||||
}
|
||||
@@ -167,19 +156,19 @@ func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
|
||||
}
|
||||
|
||||
// Create a new aggregate and store the handle.
|
||||
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
|
||||
fn := callbackHandle(db, pCtx).(func() AggregateFunction)()
|
||||
if ptr != 0 {
|
||||
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
|
||||
util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn))
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
|
||||
func callbackArgs(db *Conn, nArg, pArg uint32) []Value {
|
||||
args := make([]Value, nArg)
|
||||
for i := range args {
|
||||
args[i] = Value{
|
||||
sqlite: sqlite,
|
||||
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
|
||||
sqlite: db.sqlite,
|
||||
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
|
||||
}
|
||||
}
|
||||
return args
|
||||
|
||||
8
go.mod
8
go.mod
@@ -3,12 +3,12 @@ module github.com/ncruces/go-sqlite3
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v0.1.5
|
||||
github.com/ncruces/julianday v1.0.0
|
||||
github.com/psanford/httpreadat v0.1.0
|
||||
github.com/tetratelabs/wazero v1.5.0
|
||||
golang.org/x/sync v0.3.0
|
||||
golang.org/x/sys v0.12.0
|
||||
golang.org/x/text v0.13.0
|
||||
golang.org/x/sync v0.5.0
|
||||
golang.org/x/sys v0.14.0
|
||||
golang.org/x/text v0.14.0
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
|
||||
16
go.sum
16
go.sum
@@ -1,12 +1,12 @@
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
|
||||
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
|
||||
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
|
||||
|
||||
@@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) {
|
||||
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
UniqueValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
}
|
||||
|
||||
@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (d *ddl) clone() *ddl {
|
||||
copied := new(ddl)
|
||||
*copied = *d
|
||||
|
||||
copied.fields = make([]string, len(d.fields))
|
||||
copy(copied.fields, d.fields)
|
||||
copied.columns = make([]migrator.ColumnType, len(d.columns))
|
||||
copy(copied.columns, d.columns)
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) renameTable(dst, src string) error {
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
|
||||
if replaced == d.head {
|
||||
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
|
||||
}
|
||||
|
||||
d.head = replaced
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
@@ -208,6 +235,18 @@ func (d *ddl) removeConstraint(name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
//lint:ignore U1000 ignore unused code.
|
||||
func (d *ddl) hasConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for _, f := range d.fields {
|
||||
if reg.MatchString(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
@@ -229,3 +268,30 @@ func (d *ddl) getColumns() []string {
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (d *ddl) alterColumn(name, sql string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *ddl) removeColumn(name string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) {
|
||||
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
|
||||
}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
}},
|
||||
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
@@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
|
||||
{"with_special_characters", []string{
|
||||
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
|
||||
}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -122,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
|
||||
@@ -131,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
NameValue: sql.NullString{String: "dark_mode", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{String: "true", Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
func (_Dialector) Translate(err error) error {
|
||||
switch {
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
|
||||
|
||||
@@ -3,8 +3,8 @@ module github.com/ncruces/go-sqlite3/gormlite
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/go-sqlite3 v0.9.0
|
||||
gorm.io/gorm v1.25.4
|
||||
github.com/ncruces/go-sqlite3 v0.9.1
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -12,5 +12,5 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/ncruces/julianday v0.1.5 // indirect
|
||||
github.com/tetratelabs/wazero v1.5.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
)
|
||||
|
||||
@@ -2,15 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/ncruces/go-sqlite3 v0.9.0 h1:tl5eEmGEyzZH2ur8sDgPJTdzV4CRnKpsFngoP1QRjD8=
|
||||
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
|
||||
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
|
||||
@@ -3,7 +3,6 @@ package gormlite
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -12,11 +11,11 @@ import (
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
type _Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
func (m *_Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
var enabled int
|
||||
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
|
||||
if enabled == 1 {
|
||||
@@ -27,7 +26,7 @@ func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
return fc()
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
func (m _Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
@@ -35,7 +34,7 @@ func (m Migrator) HasTable(value interface{}) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
func (m _Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
@@ -52,11 +51,11 @@ func (m Migrator) DropTable(values ...interface{}) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
func (m _Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
func (m _Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
@@ -76,31 +75,24 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
func (m _Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
|
||||
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
|
||||
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
|
||||
|
||||
if createSQL == rawDDL {
|
||||
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
|
||||
}
|
||||
|
||||
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
}
|
||||
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
@@ -148,29 +140,23 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
func (m _Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, "")
|
||||
|
||||
return createSQL, nil, nil
|
||||
ddl.removeColumn(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
func (m _Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
@@ -185,22 +171,16 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
constraintSql = "CONSTRAINT ? CHECK (?)"
|
||||
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
} else {
|
||||
return "", nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.addConstraint(constraintName, constraintSql)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, constraintValues, nil
|
||||
ddl.addConstraint(constraintName, constraintSql)
|
||||
return ddl, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
func (m _Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
@@ -210,20 +190,14 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
}
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.removeConstraint(name)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, nil, nil
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
ddl.removeConstraint(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
func (m _Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
@@ -244,13 +218,13 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
func (m _Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
@@ -269,7 +243,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
func (m _Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
@@ -298,7 +272,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
func (m _Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
@@ -317,7 +291,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
@@ -331,7 +305,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
func (m _Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
@@ -365,7 +339,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
func (m _Migrator) getRawDDL(table string) (string, error) {
|
||||
var createSQL string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
|
||||
|
||||
@@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
return createSQL, nil
|
||||
}
|
||||
|
||||
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
|
||||
func (m _Migrator) recreateTable(
|
||||
value interface{}, tablePtr *string,
|
||||
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
|
||||
) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
table := stmt.Table
|
||||
if tablePtr != nil {
|
||||
@@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
return err
|
||||
}
|
||||
|
||||
newTableName := table + "__temp"
|
||||
|
||||
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
|
||||
originDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createSQL == "" {
|
||||
|
||||
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createDDL == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
newTableName := table + "__temp"
|
||||
if err := createDDL.renameTable(newTableName, table); err != nil {
|
||||
return err
|
||||
}
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
|
||||
createDDL, err := parseDDL(createSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
columns := createDDL.getColumns()
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return m.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -15,20 +16,26 @@ import (
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
// Open opens a GORM dialector from a data source name.
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &_Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
// Open opens a GORM dialector from a database handle.
|
||||
func OpenDB(db *sql.DB) gorm.Dialector {
|
||||
return &_Dialector{Conn: db}
|
||||
}
|
||||
|
||||
type _Dialector struct {
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
func (dialector _Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
@@ -47,7 +54,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if compareVersion(version, "3.35.0") >= 0 {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
@@ -63,7 +70,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"INSERT": func(c clause.Clause, builder clause.Builder) {
|
||||
if insert, ok := c.Expression.(clause.Insert); ok {
|
||||
@@ -112,7 +119,7 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
if field.AutoIncrement {
|
||||
return clause.Expr{SQL: "NULL"}
|
||||
}
|
||||
@@ -121,19 +128,19 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return _Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
var (
|
||||
underQuoted, selfQuoted bool
|
||||
continuousBacktick int8
|
||||
@@ -181,16 +188,17 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteString("`")
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
func (dialector _Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
func (dialector _Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement && !field.PrimaryKey {
|
||||
if field.AutoIncrement {
|
||||
// doesn't check `PrimaryKey`, to keep backward compatibility
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
@@ -214,12 +222,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -17,7 +16,7 @@ func TestDialector(t *testing.T) {
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
// Custom connection with a custom function called "my_custom_function".
|
||||
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
|
||||
db, err := driver.Open(InMemoryDSN, func(conn *sqlite3.Conn) error {
|
||||
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText("my-result")
|
||||
@@ -29,43 +28,35 @@ func TestDialector(t *testing.T) {
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector *Dialector
|
||||
dialector gorm.Dialector
|
||||
openSuccess bool
|
||||
query string
|
||||
querySuccess bool
|
||||
}{
|
||||
{
|
||||
description: "Default driver",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
description: "Default driver",
|
||||
dialector: Open(InMemoryDSN),
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom function",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
description: "Custom function",
|
||||
dialector: Open(InMemoryDSN),
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: false,
|
||||
},
|
||||
{
|
||||
description: "Custom connection",
|
||||
dialector: &Dialector{
|
||||
Conn: conn,
|
||||
},
|
||||
description: "Custom connection",
|
||||
dialector: OpenDB(db),
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom connection, custom function",
|
||||
dialector: &Dialector{
|
||||
Conn: conn,
|
||||
},
|
||||
description: "Custom connection, custom function",
|
||||
dialector: OpenDB(db),
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: true,
|
||||
|
||||
@@ -72,6 +72,7 @@ const (
|
||||
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
|
||||
IOERR_DATA = IOERR | (32 << 8)
|
||||
IOERR_CORRUPTFS = IOERR | (33 << 8)
|
||||
IOERR_IN_PAGE = IOERR | (34 << 8)
|
||||
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
|
||||
LOCKED_VTAB = LOCKED | (2 << 8)
|
||||
BUSY_RECOVERY = BUSY | (1 << 8)
|
||||
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
OffsetErr = ErrorString("sqlite3: invalid offset")
|
||||
TailErr = ErrorString("sqlite3: multiple statements")
|
||||
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
|
||||
ValueErr = ErrorString("sqlite3: unsupported value")
|
||||
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
|
||||
if size > math.MaxUint32 {
|
||||
panic(RangeErr)
|
||||
}
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
buf, ok := mod.Memory().Read(ptr, uint32(size))
|
||||
if !ok {
|
||||
panic(RangeErr)
|
||||
|
||||
56
json.go
Normal file
56
json.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// JSON returns:
|
||||
// a [json.Marshaler] that can be used as an argument to
|
||||
// [database/sql.DB.Exec] and similar methods to
|
||||
// store value as JSON; and
|
||||
// a [database/sql.Scanner] that can be used as an argument to
|
||||
// [database/sql.Row.Scan] and similar methods to
|
||||
// decode JSON into value.
|
||||
func JSON(value any) any {
|
||||
return jsonValue{value}
|
||||
}
|
||||
|
||||
type jsonValue struct{ any }
|
||||
|
||||
func (j jsonValue) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(j.any)
|
||||
}
|
||||
|
||||
func (j jsonValue) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, j.any)
|
||||
}
|
||||
|
||||
func (j jsonValue) Scan(value any) error {
|
||||
var buf []byte
|
||||
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
buf = v
|
||||
case string:
|
||||
buf = unsafe.Slice(unsafe.StringData(v), len(v))
|
||||
case int64:
|
||||
buf = strconv.AppendInt(nil, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
buf = append(buf, '"')
|
||||
buf = v.AppendFormat(buf, time.RFC3339Nano)
|
||||
buf = append(buf, '"')
|
||||
case nil:
|
||||
buf = append(buf, "null"...)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
return j.UnmarshalJSON(buf)
|
||||
}
|
||||
14
pointer.go
Normal file
14
pointer.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package sqlite3
|
||||
|
||||
// Pointer returns a pointer to a value
|
||||
// that can be used as an argument to
|
||||
// [database/sql.DB.Exec] and similar methods.
|
||||
//
|
||||
// https://www.sqlite.org/bindptr.html
|
||||
func Pointer[T any](val T) any {
|
||||
return pointer[T]{val}
|
||||
}
|
||||
|
||||
type pointer[T any] struct{ val T }
|
||||
|
||||
func (p pointer[T]) Value() any { return p.val }
|
||||
112
quote.go
Normal file
112
quote.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Quote escapes and quotes a value
|
||||
// making it safe to embed in SQL text.
|
||||
func Quote(value any) string {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "NULL"
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
} else {
|
||||
return "0"
|
||||
}
|
||||
|
||||
case int:
|
||||
return strconv.Itoa(v)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case float64:
|
||||
switch {
|
||||
case math.IsNaN(v):
|
||||
return "NULL"
|
||||
case math.IsInf(v, 1):
|
||||
return "9.0e999"
|
||||
case math.IsInf(v, -1):
|
||||
return "-9.0e999"
|
||||
}
|
||||
return strconv.FormatFloat(v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
return "'" + v.Format(time.RFC3339Nano) + "'"
|
||||
|
||||
case string:
|
||||
if strings.IndexByte(v, 0) >= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
|
||||
buf[0] = '\''
|
||||
i := 1
|
||||
for _, b := range []byte(v) {
|
||||
if b == '\'' {
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
|
||||
case []byte:
|
||||
buf := make([]byte, 3+2*len(v))
|
||||
buf[0] = 'x'
|
||||
buf[1] = '\''
|
||||
i := 2
|
||||
for _, b := range v {
|
||||
const hex = "0123456789ABCDEF"
|
||||
buf[i+0] = hex[b/16]
|
||||
buf[i+1] = hex[b%16]
|
||||
i += 2
|
||||
}
|
||||
buf[i] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
|
||||
case ZeroBlob:
|
||||
if v > ZeroBlob(1e9-3)/2 {
|
||||
break
|
||||
}
|
||||
|
||||
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
|
||||
buf[0] = 'x'
|
||||
buf[1] = '\''
|
||||
buf[len(buf)-1] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
}
|
||||
|
||||
panic(util.ValueErr)
|
||||
}
|
||||
|
||||
// QuoteIdentifier escapes and quotes an identifier
|
||||
// making it safe to embed in SQL text.
|
||||
func QuoteIdentifier(id string) string {
|
||||
if strings.IndexByte(id, 0) >= 0 {
|
||||
panic(util.ValueErr)
|
||||
}
|
||||
|
||||
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
|
||||
buf[0] = '"'
|
||||
i := 1
|
||||
for _, b := range []byte(id) {
|
||||
if b == '"' {
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = '"'
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
}
|
||||
53
sqlite.go
53
sqlite.go
@@ -22,6 +22,8 @@ import (
|
||||
var (
|
||||
Binary []byte // WASM binary to load.
|
||||
Path string // Path to load the binary from.
|
||||
|
||||
RuntimeConfig wazero.RuntimeConfig
|
||||
)
|
||||
|
||||
var instance struct {
|
||||
@@ -32,12 +34,16 @@ var instance struct {
|
||||
}
|
||||
|
||||
func compileSQLite() {
|
||||
if RuntimeConfig == nil {
|
||||
RuntimeConfig = wazero.NewRuntimeConfig()
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
instance.runtime = wazero.NewRuntime(ctx)
|
||||
instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
|
||||
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportHostFunctions(env)
|
||||
env = exportCallbacks(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
@@ -65,8 +71,6 @@ type sqlite struct {
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
type sqliteKey struct{}
|
||||
|
||||
func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
@@ -75,7 +79,6 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx = util.NewContext(context.Background())
|
||||
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
|
||||
|
||||
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
|
||||
instance.compiled, wazero.NewModuleConfig())
|
||||
@@ -117,6 +120,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
reset: getFun("sqlite3_reset"),
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
interrupt: getFun("sqlite3_interrupt"),
|
||||
progressHandler: getFun("sqlite3_progress_handler_go"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
@@ -127,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
bindPointer: getFun("sqlite3_bind_pointer_go"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
@@ -164,12 +170,15 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
valuePointer: getFun("sqlite3_value_pointer_go"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultPointer: getFun("sqlite3_result_pointer_go"),
|
||||
resultValue: getFun("sqlite3_result_value"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
||||
@@ -200,13 +209,15 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
||||
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
if handle != 0 {
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
if sql != nil {
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,7 +256,7 @@ func (sqlt *sqlite) new(size uint64) uint32 {
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
if (*[0]byte)(b) == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := sqlt.new(uint64(len(b)))
|
||||
@@ -333,6 +344,8 @@ type sqliteAPI struct {
|
||||
reset api.Function
|
||||
step api.Function
|
||||
exec api.Function
|
||||
interrupt api.Function
|
||||
progressHandler api.Function
|
||||
clearBindings api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
@@ -343,6 +356,7 @@ type sqliteAPI struct {
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
bindPointer api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
@@ -380,15 +394,30 @@ type sqliteAPI struct {
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
valuePointer api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultPointer api.Function
|
||||
resultValue api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
resultErrorMem api.Function
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
}
|
||||
|
||||
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncII(env, "go_progress", callbackProgress)
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3430100.zip"
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3440000.zip"
|
||||
unzip -d . sqlite-amalgamation-*.zip
|
||||
mv sqlite-amalgamation-*/sqlite3* .
|
||||
rm -rf sqlite-amalgamation-*
|
||||
@@ -12,24 +12,24 @@ cat *.patch | patch --posix
|
||||
|
||||
mkdir -p ext/
|
||||
cd ext/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/anycollseq.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/anycollseq.c"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/mptest/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/multiwrite01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/multiwrite01.test"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/speedtest1/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/test/speedtest1.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/test/speedtest1.c"
|
||||
cd ~-
|
||||
@@ -1,4 +1,4 @@
|
||||
#include <string.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
@@ -18,14 +18,15 @@ int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
|
||||
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
|
||||
go_func, NULL, NULL, go_destroy);
|
||||
go_func, /*step=*/NULL, /*final=*/NULL,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
|
||||
int nArg, int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, NULL, NULL,
|
||||
go_destroy);
|
||||
pApp, go_step, go_final, /*value=*/NULL,
|
||||
/*inverse=*/NULL, go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
@@ -38,3 +39,17 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
|
||||
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
|
||||
}
|
||||
|
||||
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"
|
||||
|
||||
int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
|
||||
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
|
||||
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
|
||||
}
|
||||
|
||||
void *sqlite3_value_pointer_go(sqlite3_value *val) {
|
||||
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
|
||||
}
|
||||
|
||||
34
sqlite3/isoweek.patch
Normal file
34
sqlite3/isoweek.patch
Normal file
@@ -0,0 +1,34 @@
|
||||
# ISO week date specifiers.
|
||||
# https://sqlite.org/forum/forumpost/73d99e4497e8e6a7
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -1373,6 +1373,29 @@ static void strftimeFunc(
|
||||
sqlite3_str_appendchar(&sRes, 1, c);
|
||||
break;
|
||||
}
|
||||
+ case 'V': /* Fall thru */
|
||||
+ case 'G': {
|
||||
+ DateTime y = x;
|
||||
+ computeJD(&y);
|
||||
+ y.validYMD = 0;
|
||||
+ /* Adjust date to Thursday this week:
|
||||
+ The number in parentheses is 0 for Monday, 3 for Thursday */
|
||||
+ y.iJD += (3 - (((y.iJD+43200000)/86400000) % 7))*86400000;
|
||||
+ computeYMD(&y);
|
||||
+ if( cf=='G' ){
|
||||
+ sqlite3_str_appendf(&sRes,"%04d",y.Y);
|
||||
+ }else{
|
||||
+ int nDay; /* Number of days since 1st day of year */
|
||||
+ i64 tJD = y.iJD;
|
||||
+ y.validJD = 0;
|
||||
+ y.M = 1;
|
||||
+ y.D = 1;
|
||||
+ computeJD(&y);
|
||||
+ nDay = (int)((tJD-y.iJD+43200000)/86400000);
|
||||
+ sqlite3_str_appendf(&sRes,"%02d",nDay/7+1);
|
||||
+ }
|
||||
+ break;
|
||||
+ }
|
||||
case 'Y': {
|
||||
sqlite3_str_appendf(&sRes,"%04d",x.Y);
|
||||
break;
|
||||
@@ -1,6 +1,3 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
// VFS
|
||||
@@ -14,6 +11,7 @@
|
||||
#include "ext/uint.c"
|
||||
#include "ext/uuid.c"
|
||||
#include "func.c"
|
||||
#include "progress.c"
|
||||
#include "time.c"
|
||||
|
||||
__attribute__((constructor)) void init() {
|
||||
@@ -25,4 +23,4 @@ __attribute__((constructor)) void init() {
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
|
||||
}
|
||||
}
|
||||
9
sqlite3/progress.c
Normal file
9
sqlite3/progress.c
Normal file
@@ -0,0 +1,9 @@
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_progress(void *);
|
||||
|
||||
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
|
||||
sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL);
|
||||
}
|
||||
@@ -5,12 +5,28 @@
|
||||
#define SQLITE_OS_OTHER 1
|
||||
#define SQLITE_BYTEORDER 1234
|
||||
|
||||
#define HAVE_INT8_T 1
|
||||
#define HAVE_INT16_T 1
|
||||
#define HAVE_INT32_T 1
|
||||
#define HAVE_INT64_T 1
|
||||
#define HAVE_UINT8_T 1
|
||||
#define HAVE_UINT16_T 1
|
||||
#define HAVE_UINT32_T 1
|
||||
#define HAVE_UINT64_T 1
|
||||
#define HAVE_STDINT_H 1
|
||||
#define HAVE_INTTYPES_H 1
|
||||
|
||||
#define HAVE_LOG2 1
|
||||
#define HAVE_LOG10 1
|
||||
#define HAVE_ISNAN 1
|
||||
|
||||
#define HAVE_USLEEP 1
|
||||
#define HAVE_NANOSLEEP 1
|
||||
|
||||
#define HAVE_GMTIME_R 1
|
||||
#define HAVE_LOCALTIME_S 1
|
||||
|
||||
#define HAVE_MALLOC_H 1
|
||||
#define HAVE_MALLOC_USABLE_SIZE 1
|
||||
|
||||
// Recommended Options
|
||||
@@ -23,7 +39,6 @@
|
||||
#define SQLITE_MAX_EXPR_DEPTH 0
|
||||
#define SQLITE_OMIT_DECLTYPE
|
||||
#define SQLITE_OMIT_DEPRECATED
|
||||
#define SQLITE_OMIT_PROGRESS_CALLBACK
|
||||
#define SQLITE_OMIT_SHARED_CACHE
|
||||
#define SQLITE_OMIT_AUTOINIT
|
||||
#define SQLITE_USE_ALLOCA
|
||||
@@ -52,11 +67,11 @@
|
||||
#define SQLITE_ENABLE_RTREE 1
|
||||
#define SQLITE_ENABLE_GEOPOLY 1
|
||||
|
||||
#define SQLITE_SOUNDEX
|
||||
|
||||
// Session Extension
|
||||
// #define SQLITE_ENABLE_SESSION
|
||||
// #define SQLITE_ENABLE_PREUPDATE_HOOK
|
||||
|
||||
#define SQLITE_SOUNDEX
|
||||
|
||||
// Implemented in vfs.c.
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime);
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
@@ -26,7 +27,63 @@ static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
|
||||
return rc;
|
||||
}
|
||||
|
||||
static void json_time_func(sqlite3_context *context, int argc,
|
||||
sqlite3_value **argv) {
|
||||
DateTime x;
|
||||
if (isDate(context, argc, argv, &x)) return;
|
||||
if (x.tzSet && x.tz) {
|
||||
x.iJD += x.tz * 60000;
|
||||
if (!validJulianDay(x.iJD)) return;
|
||||
x.validYMD = 0;
|
||||
x.validHMS = 0;
|
||||
}
|
||||
computeYMD_HMS(&x);
|
||||
|
||||
sqlite3 *db = sqlite3_context_db_handle(context);
|
||||
sqlite3_str *res = sqlite3_str_new(db);
|
||||
|
||||
sqlite3_str_appendf(res, "%04d-%02d-%02dT%02d:%02d:%02d", //
|
||||
x.Y, x.M, x.D, //
|
||||
x.h, x.m, (int)(x.iJD / 1000 % 60));
|
||||
|
||||
if (x.useSubsec) {
|
||||
int rem = x.iJD % 1000;
|
||||
if (rem) {
|
||||
sqlite3_str_appendchar(res, 1, '.');
|
||||
sqlite3_str_appendchar(res, 1, '0' + rem / 100);
|
||||
if ((rem %= 100)) {
|
||||
sqlite3_str_appendchar(res, 1, '0' + rem / 10);
|
||||
if ((rem %= 10)) {
|
||||
sqlite3_str_appendchar(res, 1, '0' + rem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (x.tz) {
|
||||
sqlite3_str_appendf(res, "%+03d:%02d", x.tz / 60, abs(x.tz) % 60);
|
||||
} else {
|
||||
sqlite3_str_appendchar(res, 1, 'Z');
|
||||
}
|
||||
|
||||
int rc = sqlite3_str_errcode(res);
|
||||
if (rc) {
|
||||
sqlite3_result_error_code(context, rc);
|
||||
return;
|
||||
}
|
||||
|
||||
int n = sqlite3_str_length(res);
|
||||
sqlite3_result_text(context, sqlite3_str_finish(res), n, sqlite3_free);
|
||||
}
|
||||
|
||||
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
|
||||
const sqlite3_api_routines *pApi) {
|
||||
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
|
||||
sqlite3_create_collation_v2(db, "time", SQLITE_UTF8, /*arg=*/NULL,
|
||||
time_collation,
|
||||
/*destroy=*/NULL);
|
||||
sqlite3_create_function_v2(
|
||||
db, "json_time", -1,
|
||||
SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, /*arg=*/NULL,
|
||||
json_time_func, /*step=*/NULL, /*final=*/NULL, /*destroy=*/NULL);
|
||||
return SQLITE_OK;
|
||||
}
|
||||
45
sqlite3/timezone.patch
Normal file
45
sqlite3/timezone.patch
Normal file
@@ -0,0 +1,45 @@
|
||||
# Set UTC timezone, compute local offset.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -340,6 +340,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
|
||||
p->iJD = sqlite3StmtCurrentTime(context);
|
||||
if( p->iJD>0 ){
|
||||
p->validJD = 1;
|
||||
+ p->tzSet = 1;
|
||||
return 0;
|
||||
}else{
|
||||
return 1;
|
||||
@@ -355,6 +356,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
|
||||
static void setRawDateNumber(DateTime *p, double r){
|
||||
p->s = r;
|
||||
p->rawS = 1;
|
||||
+ p->tzSet = 1;
|
||||
if( r>=0.0 && r<5373484.5 ){
|
||||
p->iJD = (sqlite3_int64)(r*86400000.0 + 0.5);
|
||||
p->validJD = 1;
|
||||
@@ -731,7 +733,16 @@ static int parseModifier(
|
||||
** show local time.
|
||||
*/
|
||||
if( sqlite3_stricmp(z, "localtime")==0 && sqlite3NotPureFunc(pCtx) ){
|
||||
- rc = toLocaltime(p, pCtx);
|
||||
+ if( p->tzSet!=0 || p->tz==0 ) {
|
||||
+ rc = toLocaltime(p, pCtx);
|
||||
+ i64 iOrigJD = p->iJD;
|
||||
+ p->tzSet = 0;
|
||||
+ computeJD(p);
|
||||
+ p->tz = (p->iJD-iOrigJD)/60000;
|
||||
+ if( abs(p->tz)>= 900 ) p->tz = 0;
|
||||
+ } else {
|
||||
+ rc = 0;
|
||||
+ }
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -781,6 +792,7 @@ static int parseModifier(
|
||||
p->validJD = 1;
|
||||
p->tzSet = 1;
|
||||
}
|
||||
+ p->tz = 0;
|
||||
rc = SQLITE_OK;
|
||||
}
|
||||
#endif
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <time.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
@@ -90,22 +92,25 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
|
||||
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
if (zVfsName) {
|
||||
static sqlite3_vfs *go_vfs_list;
|
||||
sqlite3_vfs *found = NULL;
|
||||
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
|
||||
sqlite3_vfs *it = *next;
|
||||
|
||||
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
|
||||
if (!strcmp(zVfsName, it->zName) && go_vfs_find(it->zName)) {
|
||||
return it;
|
||||
}
|
||||
}
|
||||
|
||||
for (sqlite3_vfs **ptr = &go_vfs_list; *ptr;) {
|
||||
sqlite3_vfs *it = *ptr;
|
||||
if (go_vfs_find(it->zName)) {
|
||||
if (!strcmp(zVfsName, it->zName)) found = it;
|
||||
next = &it->pNext;
|
||||
ptr = &it->pNext;
|
||||
} else {
|
||||
*next = it->pNext;
|
||||
*ptr = it->pNext;
|
||||
free(it);
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return found;
|
||||
}
|
||||
|
||||
if (go_vfs_find(zVfsName)) {
|
||||
sqlite3_vfs *prev = go_vfs_list;
|
||||
sqlite3_vfs *head = go_vfs_list;
|
||||
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
|
||||
char *name = (char *)(go_vfs_list + 1);
|
||||
strcpy(name, zVfsName);
|
||||
@@ -114,7 +119,7 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = name,
|
||||
.pNext = prev,
|
||||
.pNext = head,
|
||||
|
||||
.xOpen = go_open_wrapper,
|
||||
.xDelete = go_delete,
|
||||
@@ -132,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
return sqlite3_vfs_find_orig(zVfsName);
|
||||
}
|
||||
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
@@ -3,6 +3,7 @@ package sqlite3
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -56,6 +57,12 @@ func Test_sqlite_new(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
if os.Getenv("CI") != "" {
|
||||
t.Skip("skipping in CI")
|
||||
}
|
||||
defer func() { _ = recover() }()
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
@@ -132,6 +139,15 @@ func Test_sqlite_newBytes(t *testing.T) {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
ptr = sqlite.newBytes(buf[:0])
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
if got := util.View(sqlite.mod, ptr, 0); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sqlite_newString(t *testing.T) {
|
||||
|
||||
57
stmt.go
57
stmt.go
@@ -1,7 +1,9 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -235,7 +237,7 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
|
||||
}
|
||||
|
||||
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
const maxlen = uint64(len(time.RFC3339Nano)) + 5
|
||||
|
||||
ptr := s.c.new(maxlen)
|
||||
buf := util.View(s.c.mod, ptr, maxlen)
|
||||
@@ -248,6 +250,35 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
|
||||
// but it also associates ptr with that NULL value such that it can be retrieved
|
||||
// within an application-defined SQL function using [Value.Pointer].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindPointer(param int, ptr any) error {
|
||||
valPtr := util.AddHandle(s.c.ctx, ptr)
|
||||
r := s.c.call(s.c.api.bindPointer,
|
||||
uint64(s.handle), uint64(param), uint64(valPtr))
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindJSON binds the JSON encoding of value to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindJSON(param int, value any) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ptr := s.c.newBytes(data)
|
||||
r := s.c.call(s.c.api.bindText,
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(data)),
|
||||
uint64(s.c.api.destructor), _UTF8)
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_count.html
|
||||
@@ -402,6 +433,30 @@ func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
}
|
||||
|
||||
// ColumnJSON parses the JSON-encoded value of the result column
|
||||
// and stores it in the value pointed to by ptr.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnJSON(col int, ptr any) error {
|
||||
var data []byte
|
||||
switch s.ColumnType(col) {
|
||||
case NULL:
|
||||
data = append(data, "null"...)
|
||||
case TEXT:
|
||||
data = s.ColumnRawText(col)
|
||||
case BLOB:
|
||||
data = s.ColumnRawBlob(col)
|
||||
case INTEGER:
|
||||
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
|
||||
case FLOAT:
|
||||
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return json.Unmarshal(data, ptr)
|
||||
}
|
||||
|
||||
// Return true if stmt is an empty SQL statement.
|
||||
// This is used as an optimization.
|
||||
// It's OK to always return false here.
|
||||
|
||||
@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
|
||||
defer stmt.Close()
|
||||
|
||||
db.SetInterrupt(ctx)
|
||||
cancel()
|
||||
go cancel()
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
|
||||
@@ -25,6 +25,13 @@ func TestDB_file(t *testing.T) {
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func TestDB_nolock(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, "file:"+
|
||||
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
|
||||
"?nolock=1")
|
||||
}
|
||||
|
||||
func TestDB_wal(t *testing.T) {
|
||||
t.Parallel()
|
||||
wal := filepath.Join(t.TempDir(), "test.db")
|
||||
|
||||
@@ -36,8 +36,17 @@ func TestCreateFunction(t *testing.T) {
|
||||
case 7:
|
||||
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
|
||||
case 8:
|
||||
ctx.ResultNull()
|
||||
var v any
|
||||
if err := arg.JSON(&v); err != nil {
|
||||
ctx.ResultError(err)
|
||||
} else {
|
||||
ctx.ResultJSON(v)
|
||||
}
|
||||
case 9:
|
||||
ctx.ResultValue(arg)
|
||||
case 10:
|
||||
ctx.ResultNull()
|
||||
case 11:
|
||||
ctx.ResultError(sqlite3.FULL)
|
||||
}
|
||||
})
|
||||
@@ -45,7 +54,7 @@ func TestCreateFunction(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
|
||||
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -123,6 +132,27 @@ func TestCreateFunction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
var got int
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != 8 {
|
||||
t.Errorf("got %v, want 8", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 9 {
|
||||
t.Errorf("got %v, want 9", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
|
||||
68
tests/json_test.go
Normal file
68
tests/json_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/julianday"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
|
||||
nil, 1, math.Pi, reference,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
|
||||
sqlite3.JSON(math.Pi), sqlite3.JSON(false),
|
||||
julianday.Format(reference), sqlite3.JSON([]string{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query("SELECT * FROM test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := []string{
|
||||
"null", "1", "3.141592653589793",
|
||||
`"2013-10-07T04:23:19.12-04:00"`,
|
||||
"3.141592653589793", "false",
|
||||
"2456572.849526851851852", "[]",
|
||||
}
|
||||
for rows.Next() {
|
||||
var got json.RawMessage
|
||||
err = rows.Scan(sqlite3.JSON(&got))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != want[0] {
|
||||
t.Errorf("got %q, want %q", got, want[0])
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
}
|
||||
82
tests/quote_test.go
Normal file
82
tests/quote_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
val any
|
||||
want string
|
||||
}{
|
||||
{`abc`, "'abc'"},
|
||||
{`a"bc`, "'a\"bc'"},
|
||||
{`a'bc`, "'a''bc'"},
|
||||
{"\x07bc", "'\abc'"},
|
||||
{"\x1c\n", "'\x1c\n'"},
|
||||
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
|
||||
{"\xB0\x00\x0B", ""},
|
||||
|
||||
{0, "0"},
|
||||
{true, "1"},
|
||||
{false, "0"},
|
||||
{nil, "NULL"},
|
||||
{math.NaN(), "NULL"},
|
||||
{math.Inf(1), "9.0e999"},
|
||||
{math.Inf(-1), "-9.0e999"},
|
||||
{math.Pi, "3.141592653589793"},
|
||||
{int64(math.MaxInt64), "9223372036854775807"},
|
||||
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
|
||||
{sqlite3.ZeroBlob(4), "x'00000000'"},
|
||||
{sqlite3.ZeroBlob(1e9), ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil && tt.want != "" {
|
||||
t.Errorf("Quote(%q) = %v", tt.val, r)
|
||||
}
|
||||
}()
|
||||
|
||||
got := sqlite3.Quote(tt.val)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{`abc`, `"abc"`},
|
||||
{`a"bc`, `"a""bc"`},
|
||||
{`a'bc`, `"a'bc"`},
|
||||
{"\x07bc", "\"\abc\""},
|
||||
{"\x1c\n", "\"\x1c\n\""},
|
||||
{"\xB0\x00\x0B", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil && tt.want != "" {
|
||||
t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r)
|
||||
}
|
||||
}()
|
||||
|
||||
got := sqlite3.QuoteIdentifier(tt.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBlob(1, []byte("")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.BindJSON(1, true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stmt.Exec(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := stmt.ClearBindings(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
|
||||
// The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
|
||||
stmt, _, err = db.Prepare(`SELECT col FROM test`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
|
||||
t.Errorf("got %q, want zero", got)
|
||||
}
|
||||
var got int
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
|
||||
t.Errorf("got %q, want one", got)
|
||||
}
|
||||
var got float32
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != 1 {
|
||||
t.Errorf("got %v, want one", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
|
||||
t.Errorf("got %q, want two", got)
|
||||
}
|
||||
var got json.Number
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != "2" {
|
||||
t.Errorf("got %v, want two", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
|
||||
t.Errorf("got %q, want π", got)
|
||||
}
|
||||
var got float64
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != math.Pi {
|
||||
t.Errorf("got %v, want π", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != nil {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
|
||||
t.Errorf(`got %q, want "text"`, got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
|
||||
t.Errorf(`got %q, want "blob"`, got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != nil {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
|
||||
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
|
||||
}
|
||||
var got any
|
||||
if err := stmt.ColumnJSON(0, &got); err == nil {
|
||||
t.Errorf("got %v, want error", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(0); got != 0 {
|
||||
t.Errorf("got %v, want zero", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "true" {
|
||||
t.Errorf("got %q, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
|
||||
t.Errorf("got %q, want true", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
|
||||
if got := stmt.ColumnBlob(0, nil); got != nil {
|
||||
t.Errorf("got %q, want nil", got)
|
||||
}
|
||||
var got any = 1
|
||||
if err := stmt.ColumnJSON(0, &got); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != nil {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
func TestTimeFormat_Encode(t *testing.T) {
|
||||
@@ -39,70 +41,72 @@ func TestTimeFormat_Encode(t *testing.T) {
|
||||
func TestTimeFormat_Decode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
zone := time.FixedZone("", -4*3600)
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, zone)
|
||||
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, zone)
|
||||
|
||||
tests := []struct {
|
||||
fmt sqlite3.TimeFormat
|
||||
val any
|
||||
want time.Time
|
||||
wantDelta time.Duration
|
||||
wantLoc *time.Location
|
||||
wantErr bool
|
||||
}{
|
||||
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
|
||||
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
|
||||
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
|
||||
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, time.UTC, false},
|
||||
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
|
||||
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
|
||||
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
|
||||
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
|
||||
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
|
||||
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, zone, false},
|
||||
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
|
||||
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
|
||||
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
|
||||
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
|
||||
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, zone, false},
|
||||
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormat3, false, time.Time{}, 0, nil, true},
|
||||
{sqlite3.TimeFormat9, false, time.Time{}, 0, nil, true},
|
||||
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
|
||||
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, zone, false},
|
||||
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, time.UTC, false},
|
||||
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -112,13 +116,48 @@ func TestTimeFormat_Decode(t *testing.T) {
|
||||
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want.Sub(got).Abs() > tt.wantDelta {
|
||||
if got.Sub(tt.want).Abs() > tt.wantDelta {
|
||||
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
|
||||
}
|
||||
if got.Location().String() != tt.wantLoc.String() {
|
||||
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got.Location(), tt.wantLoc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeFormat_Scanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(
|
||||
`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
|
||||
|
||||
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.TimeFormat7TZ.Encode(reference))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got time.Time
|
||||
err = db.QueryRow("SELECT * FROM test").Scan(sqlite3.TimeFormatAuto.Scanner(&got))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.Equal(reference) {
|
||||
t.Errorf("got %v, want %v", got, reference)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_timeCollation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -167,3 +206,57 @@ func TestDB_timeCollation(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_isoWeek(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []time.Time{
|
||||
time.Date(1977, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1977, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1977, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1978, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1978, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1978, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1979, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1979, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1979, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 28, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 29, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 30, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1980, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1981, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1981, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1982, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1982, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1982, 1, 3, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT strftime('%G-W%V-%u', ?)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, tm := range tests {
|
||||
stmt.BindTime(1, tm, sqlite3.TimeFormatDefault)
|
||||
if stmt.Step() {
|
||||
y, w := tm.ISOWeek()
|
||||
d := tm.Weekday()
|
||||
if d == 0 {
|
||||
d = 7
|
||||
}
|
||||
want := fmt.Sprintf("%04d-W%02d-%d", y, w, d)
|
||||
if got := stmt.ColumnText(0); got != want {
|
||||
t.Errorf("got %q, want %q (%v)", got, want, tm)
|
||||
}
|
||||
}
|
||||
stmt.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
56
time.go
56
time.go
@@ -164,9 +164,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
case float64:
|
||||
sec, frac := math.Modf(v)
|
||||
nsec := math.Floor(frac * 1e9)
|
||||
return time.Unix(int64(sec), int64(nsec)), nil
|
||||
return time.Unix(int64(sec), int64(nsec)).UTC(), nil
|
||||
case int64:
|
||||
return time.Unix(v, 0), nil
|
||||
return time.Unix(v, 0).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -181,9 +181,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMilli(int64(math.Floor(v))), nil
|
||||
return time.UnixMilli(int64(math.Floor(v))).UTC(), nil
|
||||
case int64:
|
||||
return time.UnixMilli(int64(v)), nil
|
||||
return time.UnixMilli(int64(v)).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -198,9 +198,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.UnixMicro(int64(math.Floor(v))), nil
|
||||
return time.UnixMicro(int64(math.Floor(v))).UTC(), nil
|
||||
case int64:
|
||||
return time.UnixMicro(int64(v)), nil
|
||||
return time.UnixMicro(int64(v)).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -215,9 +215,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
return time.Unix(0, int64(math.Floor(v))), nil
|
||||
return time.Unix(0, int64(math.Floor(v))).UTC(), nil
|
||||
case int64:
|
||||
return time.Unix(0, int64(v)), nil
|
||||
return time.Unix(0, int64(v)).UTC(), nil
|
||||
default:
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
@@ -238,26 +238,16 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
}
|
||||
|
||||
dates := []TimeFormat{
|
||||
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
|
||||
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
|
||||
TimeFormat1,
|
||||
TimeFormat9, TimeFormat8,
|
||||
TimeFormat6, TimeFormat5,
|
||||
TimeFormat3, TimeFormat2, TimeFormat1,
|
||||
}
|
||||
for _, f := range dates {
|
||||
t, err := time.Parse(string(f), s)
|
||||
t, err := f.Decode(s)
|
||||
if err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
times := []TimeFormat{
|
||||
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
|
||||
}
|
||||
for _, f := range times {
|
||||
t, err := time.Parse(string(f), s)
|
||||
if err == nil {
|
||||
return t.AddDate(2000, 0, 0), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case float64:
|
||||
@@ -314,7 +304,10 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
|
||||
return time.Time{}, util.TimeErr
|
||||
}
|
||||
t, err := f.parseRelaxed(s)
|
||||
return t.AddDate(2000, 0, 0), err
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.AddDate(2000, 0, 0), nil
|
||||
|
||||
default:
|
||||
s, ok := v.(string)
|
||||
@@ -338,3 +331,20 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Scanner returns a [database/sql.Scanner] that can be used as an argument to
|
||||
// [database/sql.Row.Scan] and similar methods to
|
||||
// decode a time value into dest using this format.
|
||||
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
|
||||
return timeScanner{dest, f}
|
||||
}
|
||||
|
||||
type timeScanner struct {
|
||||
*time.Time
|
||||
TimeFormat
|
||||
}
|
||||
|
||||
func (s timeScanner) Scan(src any) (err error) {
|
||||
*s.Time, err = s.Decode(src)
|
||||
return
|
||||
}
|
||||
|
||||
33
tx.go
33
tx.go
@@ -7,6 +7,7 @@ import (
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Tx is an in-progress database transaction.
|
||||
@@ -119,17 +120,8 @@ type Savepoint struct {
|
||||
//
|
||||
// https://www.sqlite.org/lang_savepoint.html
|
||||
func (c *Conn) Savepoint() Savepoint {
|
||||
name := "sqlite3.Savepoint"
|
||||
var pc [1]uintptr
|
||||
if n := runtime.Callers(2, pc[:]); n > 0 {
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, _ := frames.Next()
|
||||
if frame.Function != "" {
|
||||
name = frame.Function
|
||||
}
|
||||
}
|
||||
// Names can be reused; this makes catching bugs more likely.
|
||||
name += "#" + strconv.Itoa(int(rand.Int31()))
|
||||
name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
|
||||
|
||||
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
|
||||
if err != nil {
|
||||
@@ -138,6 +130,27 @@ func (c *Conn) Savepoint() Savepoint {
|
||||
return Savepoint{c: c, name: name}
|
||||
}
|
||||
|
||||
func saveptName() (name string) {
|
||||
defer func() {
|
||||
if name == "" {
|
||||
name = "sqlite3.Savepoint"
|
||||
}
|
||||
}()
|
||||
|
||||
var pc [8]uintptr
|
||||
n := runtime.Callers(3, pc[:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, more := frames.Next()
|
||||
for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
|
||||
strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
|
||||
frame, more = frames.Next()
|
||||
}
|
||||
return frame.Function
|
||||
}
|
||||
|
||||
// Release releases the savepoint rolling back any changes
|
||||
// if *error points to a non-nil error.
|
||||
//
|
||||
|
||||
30
value.go
30
value.go
@@ -1,7 +1,9 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -123,3 +125,31 @@ func (v Value) rawBytes(ptr uint32) []byte {
|
||||
r := v.call(v.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
|
||||
// Pointer gets the pointer associated with this value,
|
||||
// or nil if it has no associated pointer.
|
||||
func (v Value) Pointer() any {
|
||||
r := v.call(v.api.valuePointer, uint64(v.handle))
|
||||
return util.GetHandle(v.ctx, uint32(r))
|
||||
}
|
||||
|
||||
// JSON parses a JSON-encoded value
|
||||
// and stores the result in the value pointed to by ptr.
|
||||
func (v Value) JSON(ptr any) error {
|
||||
var data []byte
|
||||
switch v.Type() {
|
||||
case NULL:
|
||||
data = append(data, "null"...)
|
||||
case TEXT:
|
||||
data = v.RawText()
|
||||
case BLOB:
|
||||
data = v.RawBlob()
|
||||
case INTEGER:
|
||||
data = strconv.AppendInt(nil, v.Int64(), 10)
|
||||
case FLOAT:
|
||||
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return json.Unmarshal(data, ptr)
|
||||
}
|
||||
|
||||
@@ -4,28 +4,4 @@ This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.htm
|
||||
|
||||
It replaces the default SQLite VFS with a pure Go implementation.
|
||||
|
||||
It also exposes interfaces that should allow you to implement your own custom VFSes.
|
||||
|
||||
## Portability
|
||||
|
||||
This package is tested on Linux, macOS and Windows,
|
||||
but it should also work on FreeBSD and illumos
|
||||
(code paths for those plaforms are tested on macOS and Linux, respectively).
|
||||
|
||||
In all platforms for which this package builds,
|
||||
it should be safe to use it to access databases concurrently,
|
||||
from multiple goroutines, processes, and
|
||||
with _other_ implementations of SQLite.
|
||||
|
||||
If the package does not build for your platform,
|
||||
you may try to use the `sqlite3_flock` and `sqlite3_nolock` build tags.
|
||||
These are only minimally tested and concurrency test failures should be expected.
|
||||
|
||||
The `sqlite3_flock` tag uses
|
||||
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
|
||||
It should be safe to access databases concurrently from multiple goroutines and processes,
|
||||
but **not** with _other_ implementations of SQLite
|
||||
(_unless_ these are _also_ configured to use `flock`).
|
||||
|
||||
The `sqlite3_nolock` tag uses no locking at all.
|
||||
Database corruption is the likely result from concurrent write access.
|
||||
It also exposes interfaces that should allow you to implement your own custom VFSes.
|
||||
23
vfs/lock.go
23
vfs/lock.go
@@ -1,12 +1,6 @@
|
||||
//go:build !sqlite3_nolock
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
import "github.com/ncruces/go-sqlite3/internal/util"
|
||||
|
||||
const (
|
||||
_PENDING_BYTE = 0x40000000
|
||||
@@ -134,18 +128,3 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
|
||||
}
|
||||
return osCheckReservedLock(f.File)
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock.
|
||||
return osWriteLock(file, _PENDING_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
// Test the RESERVED lock.
|
||||
return osCheckLock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
//go:build sqlite3_nolock
|
||||
|
||||
package vfs
|
||||
|
||||
const (
|
||||
_PENDING_BYTE = 0x40000000
|
||||
_RESERVED_BYTE = (_PENDING_BYTE + 1)
|
||||
_SHARED_FIRST = (_PENDING_BYTE + 2)
|
||||
_SHARED_SIZE = 510
|
||||
)
|
||||
|
||||
func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *vfsFile) Unlock(lock LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *vfsFile) CheckReservedLock() (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build sqlite3_flock || freebsd
|
||||
//go:build (freebsd || openbsd || netbsd || dragonfly || sqlite3_flock) && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !sqlite3_flock
|
||||
//go:build !sqlite3_flock && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,28 +1,33 @@
|
||||
//go:build sqlite3_nolock && unix && !(linux || darwin || freebsd || illumos)
|
||||
//go:build !(linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos) || sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
import "os"
|
||||
|
||||
func osUnlock(file *os.File, start, len int64) _ErrorCode {
|
||||
return _OK
|
||||
func osGetSharedLock(file *os.File) _ErrorCode {
|
||||
return _IOERR_RDLOCK
|
||||
}
|
||||
|
||||
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
|
||||
return _OK
|
||||
func osGetReservedLock(file *os.File) _ErrorCode {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
|
||||
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
|
||||
return _OK
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
|
||||
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
|
||||
return _OK
|
||||
func osGetExclusiveLock(file *os.File) _ErrorCode {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
|
||||
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
|
||||
return false, _OK
|
||||
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
return _IOERR_RDLOCK
|
||||
}
|
||||
|
||||
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
|
||||
return _IOERR_UNLOCK
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
return false, _IOERR_CHECKRESERVEDLOCK
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux || illumos) && !sqlite3_flock
|
||||
//go:build (linux || illumos) && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !unix
|
||||
//go:build !unix || sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && (!darwin || sqlite3_flock)
|
||||
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !unix
|
||||
//go:build !unix || sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !windows
|
||||
//go:build !windows || sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && (!darwin || sqlite3_flock)
|
||||
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
//go:build unix
|
||||
//go:build unix && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -32,60 +31,3 @@ func osSetMode(file *os.File, modeof string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func osGetSharedLock(file *os.File) _ErrorCode {
|
||||
// Test the PENDING lock before acquiring a new SHARED lock.
|
||||
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
|
||||
return _BUSY
|
||||
}
|
||||
// Acquire the SHARED lock.
|
||||
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File) _ErrorCode {
|
||||
// Acquire the EXCLUSIVE lock.
|
||||
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
|
||||
}
|
||||
|
||||
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
if state >= LOCK_EXCLUSIVE {
|
||||
// Downgrade to a SHARED lock.
|
||||
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
|
||||
// In theory, the downgrade to a SHARED cannot fail because another
|
||||
// process is holding an incompatible lock. If it does, this
|
||||
// indicates that the other process is not following the locking
|
||||
// protocol. If this happens, return _IOERR_RDLOCK. Returning
|
||||
// BUSY would confuse the upper layer.
|
||||
return _IOERR_RDLOCK
|
||||
}
|
||||
}
|
||||
// Release the PENDING and RESERVED locks.
|
||||
return osUnlock(file, _PENDING_BYTE, 2)
|
||||
}
|
||||
|
||||
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
|
||||
// Release all locks.
|
||||
return osUnlock(file, 0, 0)
|
||||
}
|
||||
|
||||
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
|
||||
if err == nil {
|
||||
return _OK
|
||||
}
|
||||
if errno, ok := err.(unix.Errno); ok {
|
||||
switch errno {
|
||||
case
|
||||
unix.EACCES,
|
||||
unix.EAGAIN,
|
||||
unix.EBUSY,
|
||||
unix.EINTR,
|
||||
unix.ENOLCK,
|
||||
unix.EDEADLK,
|
||||
unix.ETIMEDOUT:
|
||||
return _BUSY
|
||||
case unix.EPERM:
|
||||
return _PERM
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
82
vfs/os_unix2.go
Normal file
82
vfs/os_unix2.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//go:build (linux || darwin || freebsd || openbsd || netbsd || dragonfly || illumos) && !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func osGetSharedLock(file *os.File) _ErrorCode {
|
||||
// Test the PENDING lock before acquiring a new SHARED lock.
|
||||
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
|
||||
return _BUSY
|
||||
}
|
||||
// Acquire the SHARED lock.
|
||||
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock.
|
||||
return osWriteLock(file, _PENDING_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File) _ErrorCode {
|
||||
// Acquire the EXCLUSIVE lock.
|
||||
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
|
||||
}
|
||||
|
||||
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
if state >= LOCK_EXCLUSIVE {
|
||||
// Downgrade to a SHARED lock.
|
||||
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
|
||||
// In theory, the downgrade to a SHARED cannot fail because another
|
||||
// process is holding an incompatible lock. If it does, this
|
||||
// indicates that the other process is not following the locking
|
||||
// protocol. If this happens, return _IOERR_RDLOCK. Returning
|
||||
// BUSY would confuse the upper layer.
|
||||
return _IOERR_RDLOCK
|
||||
}
|
||||
}
|
||||
// Release the PENDING and RESERVED locks.
|
||||
return osUnlock(file, _PENDING_BYTE, 2)
|
||||
}
|
||||
|
||||
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
|
||||
// Release all locks.
|
||||
return osUnlock(file, 0, 0)
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
// Test the RESERVED lock.
|
||||
return osCheckLock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
|
||||
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
|
||||
if err == nil {
|
||||
return _OK
|
||||
}
|
||||
if errno, ok := err.(unix.Errno); ok {
|
||||
switch errno {
|
||||
case
|
||||
unix.EACCES,
|
||||
unix.EAGAIN,
|
||||
unix.EBUSY,
|
||||
unix.EINTR,
|
||||
unix.ENOLCK,
|
||||
unix.EDEADLK,
|
||||
unix.ETIMEDOUT:
|
||||
return _BUSY
|
||||
case unix.EPERM:
|
||||
return _PERM
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !sqlite3_nosys
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
@@ -39,6 +41,16 @@ func osGetSharedLock(file *os.File) _ErrorCode {
|
||||
return rc
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock.
|
||||
return osWriteLock(file, _PENDING_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File) _ErrorCode {
|
||||
// Release the SHARED lock.
|
||||
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
|
||||
@@ -90,6 +102,11 @@ func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
return _OK
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
// Test the RESERVED lock.
|
||||
return osCheckLock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
|
||||
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
|
||||
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
|
||||
0, len, 0, &windows.Overlapped{Offset: start})
|
||||
|
||||
@@ -2,6 +2,7 @@ package mptest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/bzip2"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"embed"
|
||||
@@ -24,8 +25,8 @@ import (
|
||||
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
|
||||
)
|
||||
|
||||
//go:embed testdata/mptest.wasm
|
||||
var binary []byte
|
||||
//go:embed testdata/mptest.wasm.bz2
|
||||
var compressed string
|
||||
|
||||
//go:embed testdata/*.*test
|
||||
var scripts embed.FS
|
||||
@@ -48,6 +49,11 @@ func TestMain(m *testing.M) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
module, err = rt.CompileModule(ctx, binary)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
2
vfs/tests/mptest/testdata/.gitattributes
vendored
2
vfs/tests/mptest/testdata/.gitattributes
vendored
@@ -1,2 +1,2 @@
|
||||
mptest.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
mptest.wasm.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.*test -crlf
|
||||
3
vfs/tests/mptest/testdata/build.sh
vendored
3
vfs/tests/mptest/testdata/build.sh
vendored
@@ -28,4 +28,5 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
mv mptest.tmp mptest.wasm
|
||||
mv mptest.tmp mptest.wasm
|
||||
bzip2 -9f mptest.wasm
|
||||
6
vfs/tests/mptest/testdata/main.c
vendored
6
vfs/tests/mptest/testdata/main.c
vendored
@@ -1,5 +1,4 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <unistd.h>
|
||||
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
@@ -8,9 +7,8 @@
|
||||
|
||||
__attribute__((constructor)) void init() { sqlite3_initialize(); }
|
||||
|
||||
static int dont_unlink(const char *pathname) { return 0; }
|
||||
#define sqlite3_enable_load_extension(...)
|
||||
#define sqlite3_trace(...)
|
||||
#define unlink dont_unlink
|
||||
#define unlink(...) (0)
|
||||
#undef UNUSED_PARAMETER
|
||||
#include "mptest.c"
|
||||
3
vfs/tests/mptest/testdata/mptest.wasm
vendored
3
vfs/tests/mptest/testdata/mptest.wasm
vendored
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5b77e9e13a487e976a6e71bc698542098433d1cc586ad8f24784f1f325ffb8dd
|
||||
size 1459145
|
||||
3
vfs/tests/mptest/testdata/mptest.wasm.bz2
vendored
Normal file
3
vfs/tests/mptest/testdata/mptest.wasm.bz2
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:74294bf19d213056ef5ffb7a980c3a7de5d029d0621ded53394d3055dfc4f604
|
||||
size 513604
|
||||
@@ -2,6 +2,7 @@ package speedtest1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/bzip2"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"flag"
|
||||
@@ -23,8 +24,8 @@ import (
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/speedtest1.wasm
|
||||
var binary []byte
|
||||
//go:embed testdata/speedtest1.wasm.bz2
|
||||
var compressed string
|
||||
|
||||
var (
|
||||
rt wazero.Runtime
|
||||
@@ -45,6 +46,11 @@ func TestMain(m *testing.M) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
module, err = rt.CompileModule(ctx, binary)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
2
vfs/tests/speedtest1/testdata/.gitattributes
vendored
2
vfs/tests/speedtest1/testdata/.gitattributes
vendored
@@ -1 +1 @@
|
||||
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
speedtest1.wasm.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
3
vfs/tests/speedtest1/testdata/build.sh
vendored
3
vfs/tests/speedtest1/testdata/build.sh
vendored
@@ -23,4 +23,5 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
mv speedtest1.tmp speedtest1.wasm
|
||||
mv speedtest1.tmp speedtest1.wasm
|
||||
bzip2 -9f speedtest1.wasm
|
||||
2
vfs/tests/speedtest1/testdata/main.c
vendored
2
vfs/tests/speedtest1/testdata/main.c
vendored
@@ -6,5 +6,5 @@
|
||||
// VFS
|
||||
#include "vfs.c"
|
||||
|
||||
#define randomFunc(args...) randomFunc2(args)
|
||||
#define randomFunc randomFunc2
|
||||
#include "speedtest1.c"
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3b52de3306965ac3f812592be29697d75232802a13bb16a34344f8d81dbf0637
|
||||
size 1499410
|
||||
3
vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2
vendored
Normal file
3
vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:966754393264cc43eb931ece22941d0d607e7e776e26c26b548209d2264d01a1
|
||||
size 527530
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io/fs"
|
||||
"math"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
@@ -208,6 +209,10 @@ func Test_vfsAccess(t *testing.T) {
|
||||
t.Error("can't access file")
|
||||
}
|
||||
|
||||
if usr, err := user.Current(); err == nil && usr.Uid == "0" {
|
||||
t.Skip("skipping as root")
|
||||
}
|
||||
|
||||
util.WriteString(mod, 8, file)
|
||||
rc = vfsAccess(ctx, mod, 0, 8, ACCESS_READWRITE, 4)
|
||||
if rc != _OK {
|
||||
|
||||
Reference in New Issue
Block a user