Compare commits

...

32 Commits

Author SHA1 Message Date
Nuno Cruces
828788912e JSON example. 2023-11-09 12:11:36 +00:00
Nuno Cruces
6f8645cd2e Tests. 2023-11-08 07:28:48 +00:00
Nuno Cruces
c00927e8bb Driver savepoints. 2023-11-07 15:19:40 +00:00
dependabot[bot]
6b28be6d0e Bump golang.org/x/sys from 0.13.0 to 0.14.0 (#36)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.13.0 to 0.14.0.
- [Commits](https://github.com/golang/sys/compare/v0.13.0...v0.14.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-07 15:00:00 +00:00
dependabot[bot]
310b4ff29d Bump golang.org/x/sync from 0.4.0 to 0.5.0 (#35)
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.4.0 to 0.5.0.
- [Commits](https://github.com/golang/sync/compare/v0.4.0...v0.5.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-07 12:38:12 +00:00
dependabot[bot]
e82cf16b11 Bump golang.org/x/text from 0.13.0 to 0.14.0 (#34)
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.13.0 to 0.14.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.13.0...v0.14.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-07 00:57:28 +00:00
Nuno Cruces
24c9b57c56 Pointer-passing interfaces. 2023-11-07 00:50:43 +00:00
Nuno Cruces
24b965ac7e Refactor. 2023-11-06 18:29:28 +00:00
Nuno Cruces
446168c572 Update workflows. 2023-11-04 11:21:31 +00:00
Nuno Cruces
a9e2cbbfc5 Quote values, identifiers. 2023-11-04 01:18:25 +00:00
Nuno Cruces
a7c00eb150 SQLite 3.44.0. 2023-11-03 03:43:14 -07:00
Nuno Cruces
0bcdb712ba SQL json_time function. 2023-11-03 03:40:46 -07:00
Nuno Cruces
2157d0f325 Interrupts: avoid goroutine. 2023-10-25 14:12:21 +01:00
Nuno Cruces
6353160619 Improve benchmark repeatability. 2023-10-25 13:17:37 +01:00
Nuno Cruces
501d157279 Update BSD test. 2023-10-24 23:27:10 +01:00
Nuno Cruces
4db18a7b9a JSON encoding fix. 2023-10-19 16:46:58 +01:00
Nuno Cruces
a9dddaa86c Optimize VFS search. 2023-10-19 16:43:54 +01:00
Nuno Cruces
b25936dbec Unix formats return UTC. 2023-10-19 12:07:03 +01:00
dependabot[bot]
bf23041e46 Bump github.com/ncruces/julianday from 0.1.5 to 1.0.0 (#33)
Bumps [github.com/ncruces/julianday](https://github.com/ncruces/julianday) from 0.1.5 to 1.0.0.
- [Commits](https://github.com/ncruces/julianday/compare/v0.1.5...v1.0.0)

---
updated-dependencies:
- dependency-name: github.com/ncruces/julianday
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-19 00:34:35 +01:00
Nuno Cruces
d60fceac92 JSON support. 2023-10-18 23:14:46 +01:00
Nuno Cruces
61da30f44a Allow configuring wazero. 2023-10-18 15:06:32 +01:00
Nuno Cruces
d4ff605983 Time scanner. 2023-10-18 15:06:12 +01:00
Nuno Cruces
8d0c654178 Cross compilation. 2023-10-17 15:30:08 +01:00
Nuno Cruces
728e59951b Test BSD. 2023-10-16 12:51:49 +01:00
Nuno Cruces
f7b16bad5c Patch flaky tests. 2023-10-16 12:26:25 +01:00
Nuno Cruces
db3e6da31a BSD locks. 2023-10-16 02:11:20 +01:00
Nuno Cruces
3f443b2ecc API change. 2023-10-13 18:53:37 +01:00
Nuno Cruces
eec45ea684 Towards JSON. 2023-10-13 17:06:05 +01:00
Nuno Cruces
f6d77f3cf4 GORM v1.25.5. 2023-10-13 00:42:06 +01:00
Nuno Cruces
d5d7cd1f2d SQLite 3.43.2. 2023-10-12 10:52:48 +01:00
dependabot[bot]
a33a187d48 Bump golang.org/x/sys from 0.12.0 to 0.13.0 (#30)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.12.0 to 0.13.0.
- [Commits](https://github.com/golang/sys/compare/v0.12.0...v0.13.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-05 23:53:11 +01:00
dependabot[bot]
70c6ee15c6 Bump golang.org/x/sync from 0.3.0 to 0.4.0 (#29)
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.3.0 to 0.4.0.
- [Commits](https://github.com/golang/sync/compare/v0.3.0...v0.4.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-05 23:43:11 +01:00
89 changed files with 2027 additions and 820 deletions

29
.github/workflows/bsd.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: BSD
on:
workflow_dispatch:
jobs:
test:
runs-on: macos-12
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: GOOS=freebsd go test -c ./...
- name: Test
uses: cross-platform-actions/action@v0.21.1
with:
operating_system: freebsd
version: '13.2'
sync_files: runner-to-vm
run: find . -name '*.test' -maxdepth 1 -exec {} -test.v \;

View File

@@ -1,76 +0,0 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ "main" ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ "main" ]
schedule:
- cron: '15 18 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
# Use only 'java' to analyze code written in Java, Kotlin or both
# Use only 'javascript' to analyze code written in JavaScript, TypeScript or both
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
steps:
- name: Checkout repository
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
# Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
# If the Autobuild fails above, remove it and uncomment the following three lines.
# modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
# - run: |
# echo "Run, Build Application using script"
# ./location_of_script_within_repo/buildscript.sh
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{matrix.language}}"

22
.github/workflows/cross.sh vendored Executable file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env bash
echo android ; GOOS=android GOARCH=amd64 go build .
echo darwin ; GOOS=darwin GOARCH=amd64 go build .
echo dragonfly ; GOOS=dragonfly GOARCH=amd64 go build .
echo freebsd ; GOOS=freebsd GOARCH=amd64 go build .
echo illumos ; GOOS=illumos GOARCH=amd64 go build .
echo ios ; GOOS=ios GOARCH=amd64 go build .
echo linux ; GOOS=linux GOARCH=amd64 go build .
echo netbsd ; GOOS=netbsd GOARCH=amd64 go build .
echo openbsd ; GOOS=openbsd GOARCH=amd64 go build .
echo plan9 ; GOOS=plan9 GOARCH=amd64 go build .
echo solaris ; GOOS=solaris GOARCH=amd64 go build .
echo windows ; GOOS=windows GOARCH=amd64 go build .
# echo aix ; GOOS=aix GOARCH=ppc64 go build .
echo js ; GOOS=js GOARCH=wasm go build .
echo wasip1 ; GOOS=wasip1 GOARCH=wasm go build .
echo darwin-flock ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_flock .
echo darwin-nosys ; GOOS=darwin GOARCH=amd64 go build -tags sqlite3_nosys .
echo linux-nosys ; GOOS=linux GOARCH=amd64 go build -tags sqlite3_nosys .
echo windows-nosys ; GOOS=windows GOARCH=amd64 go build -tags sqlite3_nosys .
echo freebsd-nosys ; GOOS=freebsd GOARCH=amd64 go build -tags sqlite3_nosys .

21
.github/workflows/cross.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: Cross compile
on:
workflow_dispatch:
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
lfs: 'true'
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
- name: Build
run: .github/workflows/cross.sh

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
lfs: 'true'
@@ -39,7 +39,6 @@ jobs:
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -48,8 +47,7 @@ jobs:
run: go test -v ./...
- name: Test no locks
run: go test -v -tags sqlite3_nolock .
if: matrix.os == 'ubuntu-latest'
run: go test -v -tags sqlite3_nosys ./tests -run TestDB_nolock
- name: Test BSD locks
run: go test -v -tags sqlite3_flock ./...
@@ -58,10 +56,9 @@ jobs:
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
reuse-go: 'true'
chart: true
amend: true
reuse-go: true
if: |
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'
continue-on-error: true

View File

@@ -15,16 +15,18 @@ and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- [`github.com/ncruces/go-sqlite3/ext/blob`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/blob)
simplifies incremental BLOB I/O.
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
registers Unicode aware functions.
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
implements an in-memory VFS.
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
implements a VFS for immutable databases.
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
registers Unicode aware functions.
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
@@ -43,39 +45,39 @@ To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
to always use `EXCLUSIVE` locking mode for WAL databases.
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
to use the [`database/sql`](https://pkg.go.dev/database/sql)
driver with WAL mode databases you should disable connection pooling by calling
to use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with WAL mode databases you should disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### POSIX Advisory Locks
#### File Locking
POSIX advisory locks, which SQLite uses, are
POSIX advisory locks, which SQLite uses on Unix, are
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
OFD locks are fully compatible with POSIX advisory locks.
On BSD Unixes, this module may use
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
On BSD Unixes, BSD locks are fully compatible with POSIX advisory locks.
##### TL;DR
On Windows, this module uses `LockFile`, `LockFileEx`, and `UnlockFile`,
like SQLite.
In all platforms for which this package builds,
it should be safe to use it to access databases concurrently,
from multiple goroutines, processes, and
with _other_ implementations of SQLite.
If the package does not build for your platform,
see [this](vfs/README.md#portability).
On all other platforms, file locking is not supported, and you must use
[`nolock=1`](https://www.sqlite.org/uri.html#urinolock)
to open database files.
To use the [`database/sql`](https://pkg.go.dev/database/sql) driver
with `nolock=1` you must disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Testing
The pure Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
on Linux, macOS, Windows and FreeBSD.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
@@ -86,6 +88,8 @@ Performance is tested by running
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [x] JSON support
- [ ] virtual tables
- [ ] session extension
- [ ] custom VFSes
- [x] custom VFS API

View File

@@ -118,8 +118,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
return 0, nil
}
want := int64(1024 * 1024)
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
@@ -175,8 +175,11 @@ func (b *Blob) Write(p []byte) (n int, err error) {
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want := int64(1024 * 1024)
avail := b.bytes - b.offset
want := int64(65536)
if l, ok := r.(*io.LimitedReader); ok && want > l.N {
want = l.N
}
if want > avail {
want = avail
}

78
conn.go
View File

@@ -7,10 +7,9 @@ import (
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
@@ -21,7 +20,6 @@ type Conn struct {
*sqlite
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
@@ -48,6 +46,8 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(filename, flags)
}
type connKey struct{}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
sqlite, err := instantiateSQLite()
if err != nil {
@@ -63,6 +63,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
@@ -131,7 +132,6 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
@@ -244,65 +244,40 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
return ctx
}
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Reset the pending statement.
if c.pending != nil {
// An uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
} else {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
// Remove the handler if the context can't be canceled.
if ctx == nil || ctx.Done() == nil {
c.call(c.api.progressHandler, uint64(c.handle), 0)
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
c.call(c.api.progressHandler, uint64(c.handle), 100)
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt != nil && c.interrupt.Err() != nil {
return 1
}
}
return 0
}
func (c *Conn) checkInterrupt() {
if c.interrupt != nil && c.interrupt.Err() != nil {
c.call(c.api.interrupt, uint64(c.handle))
}
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.
@@ -328,12 +303,9 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints], [online backup] and [incremental BLOB I/O].
// It can be used to access SQLite features like [online backup].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
Raw() *Conn
}

View File

@@ -97,6 +97,7 @@ const (
IOERR_ROLLBACK_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (31 << 8)
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"encoding/json"
"errors"
"math"
"time"
@@ -13,24 +14,33 @@ import (
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
*sqlite
c *Conn
handle uint32
}
// Conn returns the database connection of the
// [Conn.CreateFunction] or [Conn.CreateWindowFunction]
// routines that originally registered the application defined function.
//
// https://sqlite.org/c3ref/context_db_handle.html
func (ctx Context) Conn() *Conn {
return ctx.c
}
// SetAuxData saves metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(c.ctx, data)
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
func (ctx Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(ctx.c.ctx, data)
ctx.c.call(ctx.c.api.setAuxData, uint64(ctx.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) GetAuxData(n int) any {
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
return util.GetHandle(c.ctx, ptr)
func (ctx Context) GetAuxData(n int) any {
ptr := uint32(ctx.c.call(ctx.c.api.getAuxData, uint64(ctx.handle), uint64(n)))
return util.GetHandle(ctx.c.ctx, ptr)
}
// ResultBool sets the result of the function to a bool.
@@ -38,125 +48,160 @@ func (c Context) GetAuxData(n int) any {
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBool(value bool) {
func (ctx Context) ResultBool(value bool) {
var i int64
if value {
i = 1
}
c.ResultInt64(i)
ctx.ResultInt64(i)
}
// ResultInt sets the result of the function to an int.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt(value int) {
c.ResultInt64(int64(value))
func (ctx Context) ResultInt(value int) {
ctx.ResultInt64(int64(value))
}
// ResultInt64 sets the result of the function to an int64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt64(value int64) {
c.call(c.api.resultInteger,
uint64(c.handle), uint64(value))
func (ctx Context) ResultInt64(value int64) {
ctx.c.call(ctx.c.api.resultInteger,
uint64(ctx.handle), uint64(value))
}
// ResultFloat sets the result of the function to a float64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultFloat(value float64) {
c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value))
func (ctx Context) ResultFloat(value float64) {
ctx.c.call(ctx.c.api.resultFloat,
uint64(ctx.handle), math.Float64bits(value))
}
// ResultText sets the result of the function to a string.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultText(value string) {
ptr := c.newString(value)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor), _UTF8)
func (ctx Context) ResultText(value string) {
ptr := ctx.c.newString(value)
ctx.c.call(ctx.c.api.resultText,
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBlob(value []byte) {
ptr := c.newBytes(value)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor))
func (ctx Context) ResultBlob(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call(ctx.c.api.resultBlob,
uint64(ctx.handle), uint64(ptr), uint64(len(value)),
uint64(ctx.c.api.destructor))
}
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultZeroBlob(n int64) {
c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n))
func (ctx Context) ResultZeroBlob(n int64) {
ctx.c.call(ctx.c.api.resultZeroBlob,
uint64(ctx.handle), uint64(n))
}
// ResultNull sets the result of the function to NULL.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultNull() {
c.call(c.api.resultNull,
uint64(c.handle))
func (ctx Context) ResultNull() {
ctx.c.call(ctx.c.api.resultNull,
uint64(ctx.handle))
}
// ResultTime sets the result of the function to a [time.Time].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultTime(value time.Time, format TimeFormat) {
func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
if format == TimeFormatDefault {
c.resultRFC3339Nano(value)
ctx.resultRFC3339Nano(value)
return
}
switch v := format.Encode(value).(type) {
case string:
c.ResultText(v)
ctx.ResultText(v)
case int64:
c.ResultInt64(v)
ctx.ResultInt64(v)
case float64:
c.ResultFloat(v)
ctx.ResultFloat(v)
default:
panic(util.AssertErr())
}
}
func (c Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano))
func (ctx Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := c.new(maxlen)
buf := util.View(c.mod, ptr, maxlen)
ptr := ctx.c.new(maxlen)
buf := util.View(ctx.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.api.destructor), _UTF8)
ctx.c.call(ctx.c.api.resultText,
uint64(ctx.handle), uint64(ptr), uint64(len(buf)),
uint64(ctx.c.api.destructor), _UTF8)
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
// except that it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
ctx.ResultError(err)
}
ptr := ctx.c.newBytes(data)
ctx.c.call(ctx.c.api.resultText,
uint64(ctx.handle), uint64(ptr), uint64(len(data)),
uint64(ctx.c.api.destructor))
}
// ResultValue sets the result of the function a copy of [Value].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultValue(value Value) {
if value.sqlite != ctx.c.sqlite {
ctx.ResultError(MISUSE)
}
ctx.c.call(ctx.c.api.resultValue,
uint64(ctx.handle), uint64(value.handle))
}
// ResultError sets the result of the function an error.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultError(err error) {
func (ctx Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
c.call(c.api.resultErrorMem, uint64(c.handle))
ctx.c.call(ctx.c.api.resultErrorMem, uint64(ctx.handle))
return
}
if errors.Is(err, TOOBIG) {
c.call(c.api.resultErrorBig, uint64(c.handle))
ctx.c.call(ctx.c.api.resultErrorBig, uint64(ctx.handle))
return
}
str := err.Error()
ptr := c.newString(str)
c.call(c.api.resultError,
uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
ptr := ctx.c.newString(str)
ctx.c.call(ctx.c.api.resultError,
uint64(ctx.handle), uint64(ptr), uint64(len(str)))
ctx.c.free(ptr)
var code uint64
var ecode ErrorCode
@@ -168,7 +213,7 @@ func (c Context) ResultError(err error) {
code = uint64(ecode)
}
if code != 0 {
c.call(c.api.resultErrorCode,
uint64(c.handle), code)
ctx.c.call(ctx.c.api.resultErrorCode,
uint64(ctx.handle), code)
}
}

View File

@@ -30,6 +30,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
@@ -55,8 +57,8 @@ func init() {
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to database/sql.
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
// Any error return closes the conn and passes the error to [database/sql].
func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
if err != nil {
return nil, err
@@ -78,7 +80,7 @@ func (sqlite) OpenConnector(name string) (driver.Connector, error) {
return newConnector(name, nil)
}
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
@@ -94,7 +96,7 @@ func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn
}
type connector struct {
init func(ctx context.Context, conn *sqlite3.Conn) error
init func(*sqlite3.Conn) error
name string
txlock string
pragmas bool
@@ -132,27 +134,24 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
if err != nil {
return nil, err
}
c.reusable = true
} else {
s, _, err := c.Conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
if err != nil {
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
err = s.Close()
}
if n.init != nil {
err = n.init(c.Conn)
if err != nil {
return nil, err
}
}
if n.init != nil {
err = n.init(ctx, c.Conn)
if n.pragmas || n.init != nil {
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
if err != nil {
return nil, err
}
if s.Step() && s.ColumnBool(0) {
c.readOnly = '1'
} else {
c.readOnly = '0'
}
err = s.Close()
if err != nil {
return nil, err
}
@@ -165,7 +164,6 @@ type conn struct {
txBegin string
txCommit string
txRollback string
reusable bool
readOnly byte
}
@@ -174,7 +172,6 @@ var (
_ driver.ConnPrepareContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
)
@@ -182,10 +179,6 @@ func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
func (c *conn) IsValid() bool {
return c.reusable
}
func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
@@ -199,10 +192,10 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
c.txRollback = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
c.txCommit = c.txRollback
}
switch opts.Isolation {
@@ -233,7 +226,13 @@ func (c *conn) Commit() error {
}
func (c *conn) Rollback() error {
return c.Conn.Exec(c.txRollback)
err := c.Conn.Exec(c.txRollback)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(c.txRollback)
}
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
@@ -270,6 +269,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return nil, driver.ErrSkip
}
if savept, ok := ctx.(*saveptCtx); ok {
// Called from driver.Savepoint.
savept.Savepoint = c.Savepoint()
return resultRowsAffected(0), nil
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
@@ -281,6 +286,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return newResult(c.Conn), nil
}
func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
return nil
}
type stmt struct {
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
@@ -377,8 +386,12 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case interface{ Value() any }:
err = s.Stmt.BindPointer(id, a.Value())
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case json.Marshaler:
err = s.Stmt.BindJSON(id, a)
case nil:
err = s.Stmt.BindNull(id)
default:
@@ -395,7 +408,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
sqlite3.ZeroBlob, interface{ Value() any },
time.Time, json.Marshaler, nil:
return nil
default:
return driver.ErrSkip
@@ -474,11 +488,7 @@ func (r *rows) Next(dest []driver.Value) error {
case sqlite3.TEXT:
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
} else {
dest[i] = nil
}
dest[i] = nil
default:
panic(util.AssertErr())
}

65
driver/json_test.go Normal file
View File

@@ -0,0 +1,65 @@
package driver_test
import (
"fmt"
"log"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Example_json() {
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`
CREATE TABLE orders (
cart_id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
cart TEXT
);
`)
if err != nil {
log.Fatal(err)
}
type CartItem struct {
ItemID string `json:"id"`
Name string `json:"name"`
Quantity int `json:"quantity,omitempty"`
Price int `json:"price,omitempty"`
}
type Cart struct {
Items []CartItem `json:"items"`
}
_, err = db.Exec(`INSERT INTO orders (user_id, cart) VALUES (?, ?)`, 123, sqlite3.JSON(Cart{
[]CartItem{
{ItemID: "111", Name: "T-shirt", Quantity: 1, Price: 250},
{ItemID: "222", Name: "Trousers", Quantity: 1, Price: 600},
},
}))
if err != nil {
log.Fatal(err)
}
var total string
err = db.QueryRow(`
SELECT total(json_each.value -> 'price')
FROM orders, json_each(cart -> 'items')
WHERE cart_id = last_insert_rowid()
`).Scan(&total)
if err != nil {
log.Fatal(err)
}
fmt.Println("total:", total)
// Output:
// total: 850
}

27
driver/savepoint.go Normal file
View File

@@ -0,0 +1,27 @@
package driver
import (
"database/sql"
"time"
"github.com/ncruces/go-sqlite3"
)
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func Savepoint(tx *sql.Tx) sqlite3.Savepoint {
var ctx saveptCtx
tx.ExecContext(&ctx, "")
return ctx.Savepoint
}
type saveptCtx struct{ sqlite3.Savepoint }
func (*saveptCtx) Deadline() (deadline time.Time, ok bool) { return }
func (*saveptCtx) Done() <-chan struct{} { return nil }
func (*saveptCtx) Err() error { return nil }
func (*saveptCtx) Value(key any) any { return nil }

87
driver/savepoint_test.go Normal file
View File

@@ -0,0 +1,87 @@
package driver_test
import (
"fmt"
"log"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func ExampleSavepoint() {
db, err := driver.Open("file:/test.db?vfs=memdb", nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = func() error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO users (id, name) VALUES (?, ?)`)
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(0, "go")
if err != nil {
return err
}
_, err = stmt.Exec(1, "zig")
if err != nil {
return err
}
savept := driver.Savepoint(tx)
_, err = stmt.Exec(2, "whatever")
if err != nil {
return err
}
err = savept.Rollback()
if err != nil {
return err
}
_, err = stmt.Exec(3, "rust")
if err != nil {
return err
}
return tx.Commit()
}()
if err != nil {
log.Fatal(err)
}
rows, err := db.Query(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id, name string
err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s %s\n", id, name)
}
// Output:
// 0 go
// 1 zig
// 3 rust
}

View File

@@ -1,75 +0,0 @@
package sqlite3_test
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
func ExampleDriverConn() {
var err error
db, err = sql.Open("sqlite3", "demo.db")
if err != nil {
log.Fatal(err)
}
defer os.Remove("demo.db")
defer db.Close()
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
log.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
res, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11))
if err != nil {
log.Fatal(err)
}
id, err := res.LastInsertId()
if err != nil {
log.Fatal(err)
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn).Raw()
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {
return err
}
defer blob.Close()
_, err = fmt.Fprint(blob, "Hello BLOB!")
return err
})
if err != nil {
log.Fatal(err)
}
var msg string
err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg)
if err != nil {
log.Fatal(err)
}
fmt.Println(msg)
// Output:
// Hello BLOB!
}

View File

@@ -1,6 +1,6 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.43.1 for use with
This folder includes an embeddable WASM build of SQLite 3.44.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:

View File

@@ -13,6 +13,8 @@ sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_interrupt
sqlite3_progress_handler_go
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
@@ -23,6 +25,7 @@ sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_bind_pointer_go
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
@@ -62,12 +65,15 @@ sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_value_pointer_go
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_pointer_go
sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem

Binary file not shown.

59
ext/blob/blob.go Normal file
View File

@@ -0,0 +1,59 @@
// Package blob provides an alternative interface to incremental BLOB I/O.
package blob
import (
"errors"
"github.com/ncruces/go-sqlite3"
)
// Register registers the blob_open SQL function.
func Register(db *sqlite3.Conn) {
db.CreateFunction("blob_open", -1,
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
}
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
return
}
row := arg[3].Int64()
var err error
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
if ok {
err = blob.Reopen(row)
if errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table or column changed).
ok = false
}
}
if !ok {
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
write := arg[4].Bool()
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
}
if err != nil {
ctx.ResultError(err)
return
}
fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return
}
// This ensures the blob is closed if db, table or column change.
ctx.SetAuxData(0, blob)
ctx.SetAuxData(1, blob)
ctx.SetAuxData(2, blob)
}
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error

61
ext/blob/blob_test.go Normal file
View File

@@ -0,0 +1,61 @@
package blob_test
import (
"io"
"log"
"os"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/blob"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blob.Register(conn)
return nil
})
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
const message = "Hello BLOB!"
// Create the BLOB.
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}
// Write the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.WriteString(blob, message)
return err
}))
if err != nil {
log.Fatal(err)
}
// Read the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.Copy(os.Stdout, blob)
return err
}))
if err != nil {
log.Fatal(err)
}
// Output:
// Hello BLOB!
}

71
func.go
View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
@@ -47,6 +46,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
@@ -70,7 +70,7 @@ type AggregateFunction interface {
// The function arguments, if any, corresponding to the row being added are passed to Step.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current value of the aggregate.
// Value is invoked to return the current (or final) value of the aggregate.
Value(ctx Context)
}
@@ -85,17 +85,6 @@ type WindowFunction interface {
Inverse(ctx Context, arg ...Value)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
@@ -106,57 +95,57 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
db := ctx.Value(connKey{}).(*Conn)
fn := callbackHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
fn.Step(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, &handle).(AggregateFunction)
fn.Value(Context{db, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{sqlite, pCtx}.ResultError(err)
Context{db, pCtx}.ResultError(err)
}
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(AggregateFunction)
fn.Value(Context{db, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
db := ctx.Value(connKey{}).(*Conn)
fn := callbackAggregate(db, pCtx, nil).(WindowFunction)
fn.Inverse(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
return util.GetHandle(sqlite.ctx, pApp)
func callbackHandle(db *Conn, pCtx uint32) any {
pApp := uint32(db.call(db.api.userData, uint64(pCtx)))
return util.GetHandle(db.ctx, pApp)
}
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
func callbackAggregate(db *Conn, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
ptr := uint32(db.call(db.api.aggregateCtx, uint64(pCtx), size))
// Try loading the handle, if we already have one, or want a new one.
if ptr != 0 || size != 0 {
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
fn := util.GetHandle(sqlite.ctx, handle)
if handle := util.ReadUint32(db.mod, ptr); handle != 0 {
fn := util.GetHandle(db.ctx, handle)
if close != nil {
*close = handle
}
@@ -167,19 +156,19 @@ func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
}
// Create a new aggregate and store the handle.
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
fn := callbackHandle(db, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
util.WriteUint32(db.mod, ptr, util.AddHandle(db.ctx, fn))
}
return fn
}
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
func callbackArgs(db *Conn, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
sqlite: sqlite,
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
sqlite: db.sqlite,
handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
}
}
return args

8
go.mod
View File

@@ -3,12 +3,12 @@ module github.com/ncruces/go-sqlite3
go 1.21
require (
github.com/ncruces/julianday v0.1.5
github.com/ncruces/julianday v1.0.0
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.5.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.12.0
golang.org/x/text v0.13.0
golang.org/x/sync v0.5.0
golang.org/x/sys v0.14.0
golang.org/x/text v0.14.0
)
retract v0.4.0 // tagged from the wrong branch

16
go.sum
View File

@@ -1,12 +1,12 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=

View File

@@ -1,4 +1,4 @@
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=

View File

@@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) {
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
return &result, nil
}
func (d *ddl) clone() *ddl {
copied := new(ddl)
*copied = *d
copied.fields = make([]string, len(d.fields))
copy(copied.fields, d.fields)
copied.columns = make([]migrator.ColumnType, len(d.columns))
copy(copied.columns, d.columns)
return copied
}
func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}
func (d *ddl) renameTable(dst, src string) error {
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
if replaced == d.head {
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
}
d.head = replaced
return nil
}
func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
@@ -208,6 +235,18 @@ func (d *ddl) removeConstraint(name string) bool {
return false
}
//lint:ignore U1000 ignore unused code.
func (d *ddl) hasConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for _, f := range d.fields {
if reg.MatchString(f) {
return true
}
}
return false
}
func (d *ddl) getColumns() []string {
res := []string{}
@@ -229,3 +268,30 @@ func (d *ddl) getColumns() []string {
}
return res
}
func (d *ddl) alterColumn(name, sql string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return false
}
}
d.fields = append(d.fields, sql)
return true
}
func (d *ddl) removeColumn(name string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}
return false
}

View File

@@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) {
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
@@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
@@ -122,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
@@ -131,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},

View File

@@ -7,7 +7,7 @@ import (
"gorm.io/gorm"
)
func (dialector Dialector) Translate(err error) error {
func (_Dialector) Translate(err error) error {
switch {
case
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),

View File

@@ -3,8 +3,8 @@ module github.com/ncruces/go-sqlite3/gormlite
go 1.21
require (
github.com/ncruces/go-sqlite3 v0.9.0
gorm.io/gorm v1.25.4
github.com/ncruces/go-sqlite3 v0.9.1
gorm.io/gorm v1.25.5
)
require (
@@ -12,5 +12,5 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.5.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/sys v0.13.0 // indirect
)

View File

@@ -2,15 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.9.0 h1:tl5eEmGEyzZH2ur8sDgPJTdzV4CRnKpsFngoP1QRjD8=
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo=
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View File

@@ -3,7 +3,6 @@ package gormlite
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
@@ -12,11 +11,11 @@ import (
"gorm.io/gorm/schema"
)
type Migrator struct {
type _Migrator struct {
migrator.Migrator
}
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
func (m *_Migrator) RunWithoutForeignKey(fc func() error) error {
var enabled int
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
if enabled == 1 {
@@ -27,7 +26,7 @@ func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
return fc()
}
func (m Migrator) HasTable(value interface{}) bool {
func (m _Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
@@ -35,7 +34,7 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
func (m _Migrator) DropTable(values ...interface{}) error {
return m.RunWithoutForeignKey(func() error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
@@ -52,11 +51,11 @@ func (m Migrator) DropTable(values ...interface{}) error {
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
func (m _Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
func (m _Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
@@ -76,31 +75,24 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
func (m _Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
}
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
})
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
@@ -148,29 +140,23 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func (m _Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, "")
return createSQL, nil, nil
ddl.removeColumn(name)
return ddl, nil, nil
})
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
func (m _Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
var (
constraintName string
constraintSql string
@@ -185,22 +171,16 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
return nil, nil, nil
}
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()
return createSQL, constraintValues, nil
ddl.addConstraint(constraintName, constraintSql)
return ddl, constraintValues, nil
})
})
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
func (m _Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
@@ -210,20 +190,14 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
}
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()
return createSQL, nil, nil
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
ddl.removeConstraint(name)
return ddl, nil, nil
})
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
func (m _Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@@ -244,13 +218,13 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
func (m _Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
@@ -269,7 +243,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
func (m _Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@@ -298,7 +272,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
func (m _Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
@@ -317,7 +291,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
@@ -331,7 +305,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
func (m _Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@@ -365,7 +339,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
return
}
func (m Migrator) getRawDDL(table string) (string, error) {
func (m _Migrator) getRawDDL(table string) (string, error) {
var createSQL string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
@@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
return createSQL, nil
}
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
func (m _Migrator) recreateTable(
value interface{}, tablePtr *string,
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
@@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
return err
}
newTableName := table + "__temp"
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
originDDL, err := parseDDL(rawDDL)
if err != nil {
return err
}
if createSQL == "" {
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
if err != nil {
return err
}
if createDDL == nil {
return nil
}
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
newTableName := table + "__temp"
if err := createDDL.renameTable(newTableName, table); err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createDDL, err := parseDDL(createSQL)
if err != nil {
return err
}
columns := createDDL.getColumns()
createSQL := createDDL.compile()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {

View File

@@ -3,6 +3,7 @@ package gormlite
import (
"context"
"database/sql"
"strconv"
"gorm.io/gorm"
@@ -15,20 +16,26 @@ import (
"github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
// Open opens a GORM dialector from a data source name.
func Open(dsn string) gorm.Dialector {
return &_Dialector{DSN: dsn}
}
// Open opens a GORM dialector from a database handle.
func OpenDB(db *sql.DB) gorm.Dialector {
return &_Dialector{Conn: db}
}
type _Dialector struct {
DSN string
Conn gorm.ConnPool
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
func (dialector _Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
@@ -47,7 +54,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
@@ -63,7 +70,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
@@ -112,7 +119,7 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
if field.AutoIncrement {
return clause.Expr{SQL: "NULL"}
}
@@ -121,19 +128,19 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return _Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) {
var (
underQuoted, selfQuoted bool
continuousBacktick int8
@@ -181,16 +188,17 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString("`")
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
func (dialector _Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
func (dialector _Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
if field.AutoIncrement {
// doesn't check `PrimaryKey`, to keep backward compatibility
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
@@ -214,12 +222,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
return string(field.DataType)
}
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}

View File

@@ -1,7 +1,6 @@
package gormlite
import (
"context"
"fmt"
"testing"
@@ -17,7 +16,7 @@ func TestDialector(t *testing.T) {
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
// Custom connection with a custom function called "my_custom_function".
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
db, err := driver.Open(InMemoryDSN, func(conn *sqlite3.Conn) error {
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText("my-result")
@@ -29,43 +28,35 @@ func TestDialector(t *testing.T) {
rows := []struct {
description string
dialector *Dialector
dialector gorm.Dialector
openSuccess bool
query string
querySuccess bool
}{
{
description: "Default driver",
dialector: &Dialector{
DSN: InMemoryDSN,
},
description: "Default driver",
dialector: Open(InMemoryDSN),
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom function",
dialector: &Dialector{
DSN: InMemoryDSN,
},
description: "Custom function",
dialector: Open(InMemoryDSN),
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: false,
},
{
description: "Custom connection",
dialector: &Dialector{
Conn: conn,
},
description: "Custom connection",
dialector: OpenDB(db),
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom connection, custom function",
dialector: &Dialector{
Conn: conn,
},
description: "Custom connection, custom function",
dialector: OpenDB(db),
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: true,

View File

@@ -72,6 +72,7 @@ const (
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
IOERR_DATA = IOERR | (32 << 8)
IOERR_CORRUPTFS = IOERR | (33 << 8)
IOERR_IN_PAGE = IOERR | (34 << 8)
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
LOCKED_VTAB = LOCKED | (2 << 8)
BUSY_RECOVERY = BUSY | (1 << 8)

View File

@@ -23,6 +23,7 @@ const (
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
ValueErr = ErrorString("sqlite3: unsupported value")
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
)

View File

@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
if size > math.MaxUint32 {
panic(RangeErr)
}
if size == 0 {
return nil
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)

56
json.go Normal file
View File

@@ -0,0 +1,56 @@
package sqlite3
import (
"encoding/json"
"strconv"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// JSON returns:
// a [json.Marshaler] that can be used as an argument to
// [database/sql.DB.Exec] and similar methods to
// store value as JSON; and
// a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode JSON into value.
func JSON(value any) any {
return jsonValue{value}
}
type jsonValue struct{ any }
func (j jsonValue) MarshalJSON() ([]byte, error) {
return json.Marshal(j.any)
}
func (j jsonValue) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, j.any)
}
func (j jsonValue) Scan(value any) error {
var buf []byte
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = append(buf, "null"...)
default:
panic(util.AssertErr())
}
return j.UnmarshalJSON(buf)
}

14
pointer.go Normal file
View File

@@ -0,0 +1,14 @@
package sqlite3
// Pointer returns a pointer to a value
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
//
// https://www.sqlite.org/bindptr.html
func Pointer[T any](val T) any {
return pointer[T]{val}
}
type pointer[T any] struct{ val T }
func (p pointer[T]) Value() any { return p.val }

112
quote.go Normal file
View File

@@ -0,0 +1,112 @@
package sqlite3
import (
"bytes"
"math"
"strconv"
"strings"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Quote escapes and quotes a value
// making it safe to embed in SQL text.
func Quote(value any) string {
switch v := value.(type) {
case nil:
return "NULL"
case bool:
if v {
return "1"
} else {
return "0"
}
case int:
return strconv.Itoa(v)
case int64:
return strconv.FormatInt(v, 10)
case float64:
switch {
case math.IsNaN(v):
return "NULL"
case math.IsInf(v, 1):
return "9.0e999"
case math.IsInf(v, -1):
return "-9.0e999"
}
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return "'" + v.Format(time.RFC3339Nano) + "'"
case string:
if strings.IndexByte(v, 0) >= 0 {
break
}
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
buf[0] = '\''
i := 1
for _, b := range []byte(v) {
if b == '\'' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[i] = '\''
return unsafe.String(&buf[0], len(buf))
case []byte:
buf := make([]byte, 3+2*len(v))
buf[0] = 'x'
buf[1] = '\''
i := 2
for _, b := range v {
const hex = "0123456789ABCDEF"
buf[i+0] = hex[b/16]
buf[i+1] = hex[b%16]
i += 2
}
buf[i] = '\''
return unsafe.String(&buf[0], len(buf))
case ZeroBlob:
if v > ZeroBlob(1e9-3)/2 {
break
}
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[0] = 'x'
buf[1] = '\''
buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf))
}
panic(util.ValueErr)
}
// QuoteIdentifier escapes and quotes an identifier
// making it safe to embed in SQL text.
func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr)
}
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
buf[0] = '"'
i := 1
for _, b := range []byte(id) {
if b == '"' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[i] = '"'
return unsafe.String(&buf[0], len(buf))
}

View File

@@ -22,6 +22,8 @@ import (
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
RuntimeConfig wazero.RuntimeConfig
)
var instance struct {
@@ -32,12 +34,16 @@ var instance struct {
}
func compileSQLite() {
if RuntimeConfig == nil {
RuntimeConfig = wazero.NewRuntimeConfig()
}
ctx := context.Background()
instance.runtime = wazero.NewRuntime(ctx)
instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportHostFunctions(env)
env = exportCallbacks(env)
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
@@ -65,8 +71,6 @@ type sqlite struct {
stack [8]uint64
}
type sqliteKey struct{}
func instantiateSQLite() (sqlt *sqlite, err error) {
instance.once.Do(compileSQLite)
if instance.err != nil {
@@ -75,7 +79,6 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig())
@@ -117,6 +120,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
interrupt: getFun("sqlite3_interrupt"),
progressHandler: getFun("sqlite3_progress_handler_go"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
@@ -127,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindPointer: getFun("sqlite3_bind_pointer_go"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
@@ -164,12 +170,15 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
valuePointer: getFun("sqlite3_value_pointer_go"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultPointer: getFun("sqlite3_result_pointer_go"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
@@ -200,13 +209,15 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if handle != 0 {
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
}
@@ -245,7 +256,7 @@ func (sqlt *sqlite) new(size uint64) uint32 {
}
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil {
if (*[0]byte)(b) == nil {
return 0
}
ptr := sqlt.new(uint64(len(b)))
@@ -333,6 +344,8 @@ type sqliteAPI struct {
reset api.Function
step api.Function
exec api.Function
interrupt api.Function
progressHandler api.Function
clearBindings api.Function
bindCount api.Function
bindIndex api.Function
@@ -343,6 +356,7 @@ type sqliteAPI struct {
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindPointer api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
@@ -380,15 +394,30 @@ type sqliteAPI struct {
valueText api.Function
valueBlob api.Function
valueBytes api.Function
valuePointer api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultPointer api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
destructor uint32
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", callbackProgress)
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}

View File

@@ -3,7 +3,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3430100.zip"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3440000.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
@@ -12,24 +12,24 @@ cat *.patch | patch --posix
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/ext/misc/anycollseq.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/multiwrite01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/test/speedtest1.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.44.0/test/speedtest1.c"
cd ~-

View File

@@ -1,4 +1,4 @@
#include <string.h>
#include <stddef.h>
#include "sqlite3.h"
@@ -18,14 +18,15 @@ int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
go_func, NULL, NULL, go_destroy);
go_func, /*step=*/NULL, /*final=*/NULL,
go_destroy);
}
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
int nArg, int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, NULL, NULL,
go_destroy);
pApp, go_step, go_final, /*value=*/NULL,
/*inverse=*/NULL, go_destroy);
}
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
@@ -38,3 +39,17 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"
int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
}
void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
}
void *sqlite3_value_pointer_go(sqlite3_value *val) {
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
}

34
sqlite3/isoweek.patch Normal file
View File

@@ -0,0 +1,34 @@
# ISO week date specifiers.
# https://sqlite.org/forum/forumpost/73d99e4497e8e6a7
--- sqlite3.c.orig
+++ sqlite3.c
@@ -1373,6 +1373,29 @@ static void strftimeFunc(
sqlite3_str_appendchar(&sRes, 1, c);
break;
}
+ case 'V': /* Fall thru */
+ case 'G': {
+ DateTime y = x;
+ computeJD(&y);
+ y.validYMD = 0;
+ /* Adjust date to Thursday this week:
+ The number in parentheses is 0 for Monday, 3 for Thursday */
+ y.iJD += (3 - (((y.iJD+43200000)/86400000) % 7))*86400000;
+ computeYMD(&y);
+ if( cf=='G' ){
+ sqlite3_str_appendf(&sRes,"%04d",y.Y);
+ }else{
+ int nDay; /* Number of days since 1st day of year */
+ i64 tJD = y.iJD;
+ y.validJD = 0;
+ y.M = 1;
+ y.D = 1;
+ computeJD(&y);
+ nDay = (int)((tJD-y.iJD+43200000)/86400000);
+ sqlite3_str_appendf(&sRes,"%02d",nDay/7+1);
+ }
+ break;
+ }
case 'Y': {
sqlite3_str_appendf(&sRes,"%04d",x.Y);
break;

View File

@@ -1,6 +1,3 @@
#include <stdbool.h>
#include <stddef.h>
// Amalgamation
#include "sqlite3.c"
// VFS
@@ -14,6 +11,7 @@
#include "ext/uint.c"
#include "ext/uuid.c"
#include "func.c"
#include "progress.c"
#include "time.c"
__attribute__((constructor)) void init() {
@@ -25,4 +23,4 @@ __attribute__((constructor)) void init() {
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}
}

9
sqlite3/progress.c Normal file
View File

@@ -0,0 +1,9 @@
#include <stddef.h>
#include "sqlite3.h"
int go_progress(void *);
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL);
}

View File

@@ -5,12 +5,28 @@
#define SQLITE_OS_OTHER 1
#define SQLITE_BYTEORDER 1234
#define HAVE_INT8_T 1
#define HAVE_INT16_T 1
#define HAVE_INT32_T 1
#define HAVE_INT64_T 1
#define HAVE_UINT8_T 1
#define HAVE_UINT16_T 1
#define HAVE_UINT32_T 1
#define HAVE_UINT64_T 1
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define HAVE_LOG2 1
#define HAVE_LOG10 1
#define HAVE_ISNAN 1
#define HAVE_USLEEP 1
#define HAVE_NANOSLEEP 1
#define HAVE_GMTIME_R 1
#define HAVE_LOCALTIME_S 1
#define HAVE_MALLOC_H 1
#define HAVE_MALLOC_USABLE_SIZE 1
// Recommended Options
@@ -23,7 +39,6 @@
#define SQLITE_MAX_EXPR_DEPTH 0
#define SQLITE_OMIT_DECLTYPE
#define SQLITE_OMIT_DEPRECATED
#define SQLITE_OMIT_PROGRESS_CALLBACK
#define SQLITE_OMIT_SHARED_CACHE
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
@@ -52,11 +67,11 @@
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
#define SQLITE_SOUNDEX
// Session Extension
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
#define SQLITE_SOUNDEX
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

View File

@@ -1,3 +1,4 @@
#include <stddef.h>
#include <string.h>
#include "sqlite3.h"
@@ -26,7 +27,63 @@ static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
return rc;
}
static void json_time_func(sqlite3_context *context, int argc,
sqlite3_value **argv) {
DateTime x;
if (isDate(context, argc, argv, &x)) return;
if (x.tzSet && x.tz) {
x.iJD += x.tz * 60000;
if (!validJulianDay(x.iJD)) return;
x.validYMD = 0;
x.validHMS = 0;
}
computeYMD_HMS(&x);
sqlite3 *db = sqlite3_context_db_handle(context);
sqlite3_str *res = sqlite3_str_new(db);
sqlite3_str_appendf(res, "%04d-%02d-%02dT%02d:%02d:%02d", //
x.Y, x.M, x.D, //
x.h, x.m, (int)(x.iJD / 1000 % 60));
if (x.useSubsec) {
int rem = x.iJD % 1000;
if (rem) {
sqlite3_str_appendchar(res, 1, '.');
sqlite3_str_appendchar(res, 1, '0' + rem / 100);
if ((rem %= 100)) {
sqlite3_str_appendchar(res, 1, '0' + rem / 10);
if ((rem %= 10)) {
sqlite3_str_appendchar(res, 1, '0' + rem);
}
}
}
}
if (x.tz) {
sqlite3_str_appendf(res, "%+03d:%02d", x.tz / 60, abs(x.tz) % 60);
} else {
sqlite3_str_appendchar(res, 1, 'Z');
}
int rc = sqlite3_str_errcode(res);
if (rc) {
sqlite3_result_error_code(context, rc);
return;
}
int n = sqlite3_str_length(res);
sqlite3_result_text(context, sqlite3_str_finish(res), n, sqlite3_free);
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
sqlite3_create_collation_v2(db, "time", SQLITE_UTF8, /*arg=*/NULL,
time_collation,
/*destroy=*/NULL);
sqlite3_create_function_v2(
db, "json_time", -1,
SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, /*arg=*/NULL,
json_time_func, /*step=*/NULL, /*final=*/NULL, /*destroy=*/NULL);
return SQLITE_OK;
}

45
sqlite3/timezone.patch Normal file
View File

@@ -0,0 +1,45 @@
# Set UTC timezone, compute local offset.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -340,6 +340,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
p->iJD = sqlite3StmtCurrentTime(context);
if( p->iJD>0 ){
p->validJD = 1;
+ p->tzSet = 1;
return 0;
}else{
return 1;
@@ -355,6 +356,7 @@ static int setDateTimeToCurrent(sqlite3_context *context, DateTime *p){
static void setRawDateNumber(DateTime *p, double r){
p->s = r;
p->rawS = 1;
+ p->tzSet = 1;
if( r>=0.0 && r<5373484.5 ){
p->iJD = (sqlite3_int64)(r*86400000.0 + 0.5);
p->validJD = 1;
@@ -731,7 +733,16 @@ static int parseModifier(
** show local time.
*/
if( sqlite3_stricmp(z, "localtime")==0 && sqlite3NotPureFunc(pCtx) ){
- rc = toLocaltime(p, pCtx);
+ if( p->tzSet!=0 || p->tz==0 ) {
+ rc = toLocaltime(p, pCtx);
+ i64 iOrigJD = p->iJD;
+ p->tzSet = 0;
+ computeJD(p);
+ p->tz = (p->iJD-iOrigJD)/60000;
+ if( abs(p->tz)>= 900 ) p->tz = 0;
+ } else {
+ rc = 0;
+ }
}
break;
}
@@ -781,6 +792,7 @@ static int parseModifier(
p->validJD = 1;
p->tzSet = 1;
}
+ p->tz = 0;
rc = SQLITE_OK;
}
#endif

View File

@@ -1,3 +1,5 @@
#include <stdbool.h>
#include <stddef.h>
#include <time.h>
#include "sqlite3.h"
@@ -90,22 +92,25 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
if (zVfsName) {
static sqlite3_vfs *go_vfs_list;
sqlite3_vfs *found = NULL;
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
sqlite3_vfs *it = *next;
for (sqlite3_vfs *it = go_vfs_list; it; it = it->pNext) {
if (!strcmp(zVfsName, it->zName) && go_vfs_find(it->zName)) {
return it;
}
}
for (sqlite3_vfs **ptr = &go_vfs_list; *ptr;) {
sqlite3_vfs *it = *ptr;
if (go_vfs_find(it->zName)) {
if (!strcmp(zVfsName, it->zName)) found = it;
next = &it->pNext;
ptr = &it->pNext;
} else {
*next = it->pNext;
*ptr = it->pNext;
free(it);
}
}
if (found) {
return found;
}
if (go_vfs_find(zVfsName)) {
sqlite3_vfs *prev = go_vfs_list;
sqlite3_vfs *head = go_vfs_list;
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
char *name = (char *)(go_vfs_list + 1);
strcpy(name, zVfsName);
@@ -114,7 +119,7 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
.szOsFile = sizeof(struct go_file),
.mxPathname = 512,
.zName = name,
.pNext = prev,
.pNext = head,
.xOpen = go_open_wrapper,
.xDelete = go_delete,
@@ -132,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
return sqlite3_vfs_find_orig(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");

View File

@@ -3,6 +3,7 @@ package sqlite3
import (
"bytes"
"math"
"os"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -56,6 +57,12 @@ func Test_sqlite_new(t *testing.T) {
})
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
defer func() { _ = recover() }()
sqlite.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
@@ -132,6 +139,15 @@ func Test_sqlite_newBytes(t *testing.T) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
ptr = sqlite.newBytes(buf[:0])
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
if got := util.View(sqlite.mod, ptr, 0); got != nil {
t.Errorf("got %q, want nil", got)
}
}
func Test_sqlite_newString(t *testing.T) {

57
stmt.go
View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -235,7 +237,7 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
const maxlen = uint64(len(time.RFC3339Nano)) + 5
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
@@ -248,6 +250,35 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r)
}
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
// but it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
r := s.c.call(s.c.api.bindPointer,
uint64(s.handle), uint64(param), uint64(valPtr))
return s.c.error(r)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
ptr := s.c.newBytes(data)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(data)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -402,6 +433,30 @@ func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
return util.View(s.c.mod, ptr, r)
}
// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
var data []byte
switch s.ColumnType(col) {
case NULL:
data = append(data, "null"...)
case TEXT:
data = s.ColumnRawText(col)
case BLOB:
data = s.ColumnRawBlob(col)
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.

View File

@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
go cancel()
// Interrupting works.
err = stmt.Exec()

View File

@@ -25,6 +25,13 @@ func TestDB_file(t *testing.T) {
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func TestDB_nolock(t *testing.T) {
t.Parallel()
testDB(t, "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?nolock=1")
}
func TestDB_wal(t *testing.T) {
t.Parallel()
wal := filepath.Join(t.TempDir(), "test.db")

View File

@@ -36,8 +36,17 @@ func TestCreateFunction(t *testing.T) {
case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8:
ctx.ResultNull()
var v any
if err := arg.JSON(&v); err != nil {
ctx.ResultError(err)
} else {
ctx.ResultJSON(v)
}
case 9:
ctx.ResultValue(arg)
case 10:
ctx.ResultNull()
case 11:
ctx.ResultError(sqlite3.FULL)
}
})
@@ -45,7 +54,7 @@ func TestCreateFunction(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
if err != nil {
t.Error(err)
}
@@ -123,6 +132,27 @@ func TestCreateFunction(t *testing.T) {
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 8 {
t.Errorf("got %v, want 8", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 9 {
t.Errorf("got %v, want 9", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)

68
tests/json_test.go Normal file
View File

@@ -0,0 +1,68 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/julianday"
)
func TestJSON(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
nil, 1, math.Pi, reference,
)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
sqlite3.JSON(math.Pi), sqlite3.JSON(false),
julianday.Format(reference), sqlite3.JSON([]string{}))
if err != nil {
t.Fatal(err)
}
rows, err := db.Query("SELECT * FROM test")
if err != nil {
t.Fatal(err)
}
want := []string{
"null", "1", "3.141592653589793",
`"2013-10-07T04:23:19.12-04:00"`,
"3.141592653589793", "false",
"2456572.849526851851852", "[]",
}
for rows.Next() {
var got json.RawMessage
err = rows.Scan(sqlite3.JSON(&got))
if err != nil {
t.Fatal(err)
}
if string(got) != want[0] {
t.Errorf("got %q, want %q", got, want[0])
}
want = want[1:]
}
}

82
tests/quote_test.go Normal file
View File

@@ -0,0 +1,82 @@
package tests
import (
"math"
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestQuote(t *testing.T) {
tests := []struct {
val any
want string
}{
{`abc`, "'abc'"},
{`a"bc`, "'a\"bc'"},
{`a'bc`, "'a''bc'"},
{"\x07bc", "'\abc'"},
{"\x1c\n", "'\x1c\n'"},
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
{"\xB0\x00\x0B", ""},
{0, "0"},
{true, "1"},
{false, "0"},
{nil, "NULL"},
{math.NaN(), "NULL"},
{math.Inf(1), "9.0e999"},
{math.Inf(-1), "-9.0e999"},
{math.Pi, "3.141592653589793"},
{int64(math.MaxInt64), "9223372036854775807"},
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
{sqlite3.ZeroBlob(4), "x'00000000'"},
{sqlite3.ZeroBlob(1e9), ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && tt.want != "" {
t.Errorf("Quote(%q) = %v", tt.val, r)
}
}()
got := sqlite3.Quote(tt.val)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
}
})
}
}
func TestQuoteIdentifier(t *testing.T) {
tests := []struct {
id string
want string
}{
{`abc`, `"abc"`},
{`a"bc`, `"a""bc"`},
{`a'bc`, `"a'bc"`},
{"\x07bc", "\"\abc\""},
{"\x1c\n", "\"\x1c\n\""},
{"\xB0\x00\x0B", ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && tt.want != "" {
t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r)
}
}()
got := sqlite3.QuoteIdentifier(tt.id)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
}
})
}
}

View File

@@ -1,6 +1,7 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindJSON(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
// The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 0 {
t.Errorf("got %v, want zero", got)
}
}
if stmt.Step() {
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
var got float32
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 1 {
t.Errorf("got %v, want one", got)
}
}
if stmt.Step() {
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
var got json.Number
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != "2" {
t.Errorf("got %v, want two", got)
}
}
if stmt.Step() {
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
var got float64
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != math.Pi {
t.Errorf("got %v, want π", got)
}
}
if stmt.Step() {
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "true" {
t.Errorf("got %q, want true", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
t.Errorf("got %q, want true", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if err := stmt.Close(); err != nil {

View File

@@ -1,11 +1,13 @@
package tests
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
)
func TestTimeFormat_Encode(t *testing.T) {
@@ -39,70 +41,72 @@ func TestTimeFormat_Encode(t *testing.T) {
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
zone := time.FixedZone("", -4*3600)
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, zone)
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, zone)
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantLoc *time.Location
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, time.UTC, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, time.UTC, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, time.UTC, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, time.UTC, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, time.UTC, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, time.UTC, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, zone, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, nil, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, time.UTC, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, zone, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, time.UTC, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, nil, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, zone, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, time.UTC, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, nil, true},
}
for _, tt := range tests {
@@ -112,13 +116,48 @@ func TestTimeFormat_Decode(t *testing.T) {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
if got.Sub(tt.want).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
if got.Location().String() != tt.wantLoc.String() {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got.Location(), tt.wantLoc)
}
})
}
}
func TestTimeFormat_Scanner(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(
`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.TimeFormat7TZ.Encode(reference))
if err != nil {
t.Fatal(err)
}
var got time.Time
err = db.QueryRow("SELECT * FROM test").Scan(sqlite3.TimeFormatAuto.Scanner(&got))
if err != nil {
t.Fatal(err)
}
if !got.Equal(reference) {
t.Errorf("got %v, want %v", got, reference)
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
@@ -167,3 +206,57 @@ func TestDB_timeCollation(t *testing.T) {
t.Fatal(err)
}
}
func TestDB_isoWeek(t *testing.T) {
t.Parallel()
tests := []time.Time{
time.Date(1977, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1977, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1977, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1978, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1978, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1978, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1979, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1979, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1979, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1980, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 28, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 29, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 30, 0, 0, 0, 0, time.UTC),
time.Date(1980, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1981, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1981, 12, 31, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 2, 0, 0, 0, 0, time.UTC),
time.Date(1982, 1, 3, 0, 0, 0, 0, time.UTC),
}
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT strftime('%G-W%V-%u', ?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
for _, tm := range tests {
stmt.BindTime(1, tm, sqlite3.TimeFormatDefault)
if stmt.Step() {
y, w := tm.ISOWeek()
d := tm.Weekday()
if d == 0 {
d = 7
}
want := fmt.Sprintf("%04d-W%02d-%d", y, w, d)
if got := stmt.ColumnText(0); got != want {
t.Errorf("got %q, want %q (%v)", got, want, tm)
}
}
stmt.Reset()
}
}

56
time.go
View File

@@ -164,9 +164,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)), nil
return time.Unix(int64(sec), int64(nsec)).UTC(), nil
case int64:
return time.Unix(v, 0), nil
return time.Unix(v, 0).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -181,9 +181,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(math.Floor(v))), nil
return time.UnixMilli(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMilli(int64(v)), nil
return time.UnixMilli(int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -198,9 +198,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(math.Floor(v))), nil
return time.UnixMicro(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMicro(int64(v)), nil
return time.UnixMicro(int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -215,9 +215,9 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(math.Floor(v))), nil
return time.Unix(0, int64(math.Floor(v))).UTC(), nil
case int64:
return time.Unix(0, int64(v)), nil
return time.Unix(0, int64(v)).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
@@ -238,26 +238,16 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
dates := []TimeFormat{
TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3,
TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2,
TimeFormat1,
TimeFormat9, TimeFormat8,
TimeFormat6, TimeFormat5,
TimeFormat3, TimeFormat2, TimeFormat1,
}
for _, f := range dates {
t, err := time.Parse(string(f), s)
t, err := f.Decode(s)
if err == nil {
return t, nil
}
}
times := []TimeFormat{
TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8,
}
for _, f := range times {
t, err := time.Parse(string(f), s)
if err == nil {
return t.AddDate(2000, 0, 0), nil
}
}
}
switch v := v.(type) {
case float64:
@@ -314,7 +304,10 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
if err != nil {
return time.Time{}, err
}
return t.AddDate(2000, 0, 0), nil
default:
s, ok := v.(string)
@@ -338,3 +331,20 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
}
return t, nil
}
// Scanner returns a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode a time value into dest using this format.
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
return timeScanner{dest, f}
}
type timeScanner struct {
*time.Time
TimeFormat
}
func (s timeScanner) Scan(src any) (err error) {
*s.Time, err = s.Decode(src)
return
}

33
tx.go
View File

@@ -7,6 +7,7 @@ import (
"math/rand"
"runtime"
"strconv"
"strings"
)
// Tx is an in-progress database transaction.
@@ -119,17 +120,8 @@ type Savepoint struct {
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
@@ -138,6 +130,27 @@ func (c *Conn) Savepoint() Savepoint {
return Savepoint{c: c, name: name}
}
func saveptName() (name string) {
defer func() {
if name == "" {
name = "sqlite3.Savepoint"
}
}()
var pc [8]uintptr
n := runtime.Callers(3, pc[:])
if n <= 0 {
return ""
}
frames := runtime.CallersFrames(pc[:n])
frame, more := frames.Next()
for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
frame, more = frames.Next()
}
return frame.Function
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//

View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -123,3 +125,31 @@ func (v Value) rawBytes(ptr uint32) []byte {
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}
// Pointer gets the pointer associated with this value,
// or nil if it has no associated pointer.
func (v Value) Pointer() any {
r := v.call(v.api.valuePointer, uint64(v.handle))
return util.GetHandle(v.ctx, uint32(r))
}
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
var data []byte
switch v.Type() {
case NULL:
data = append(data, "null"...)
case TEXT:
data = v.RawText()
case BLOB:
data = v.RawBlob()
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}

View File

@@ -4,28 +4,4 @@ This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.htm
It replaces the default SQLite VFS with a pure Go implementation.
It also exposes interfaces that should allow you to implement your own custom VFSes.
## Portability
This package is tested on Linux, macOS and Windows,
but it should also work on FreeBSD and illumos
(code paths for those plaforms are tested on macOS and Linux, respectively).
In all platforms for which this package builds,
it should be safe to use it to access databases concurrently,
from multiple goroutines, processes, and
with _other_ implementations of SQLite.
If the package does not build for your platform,
you may try to use the `sqlite3_flock` and `sqlite3_nolock` build tags.
These are only minimally tested and concurrency test failures should be expected.
The `sqlite3_flock` tag uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
It should be safe to access databases concurrently from multiple goroutines and processes,
but **not** with _other_ implementations of SQLite
(_unless_ these are _also_ configured to use `flock`).
The `sqlite3_nolock` tag uses no locking at all.
Database corruption is the likely result from concurrent write access.
It also exposes interfaces that should allow you to implement your own custom VFSes.

View File

@@ -1,12 +1,6 @@
//go:build !sqlite3_nolock
package vfs
import (
"os"
"github.com/ncruces/go-sqlite3/internal/util"
)
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_PENDING_BYTE = 0x40000000
@@ -134,18 +128,3 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
}
return osCheckReservedLock(f.File)
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

View File

@@ -1,22 +0,0 @@
//go:build sqlite3_nolock
package vfs
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func (f *vfsFile) Lock(lock LockLevel) error {
return nil
}
func (f *vfsFile) Unlock(lock LockLevel) error {
return nil
}
func (f *vfsFile) CheckReservedLock() (bool, error) {
return false, nil
}

View File

@@ -1,4 +1,4 @@
//go:build sqlite3_flock || freebsd
//go:build (freebsd || openbsd || netbsd || dragonfly || sqlite3_flock) && !sqlite3_nosys
package vfs

View File

@@ -1,4 +1,4 @@
//go:build !sqlite3_flock
//go:build !sqlite3_flock && !sqlite3_nosys
package vfs

View File

@@ -1,3 +1,5 @@
//go:build !sqlite3_nosys
package vfs
import (

View File

@@ -1,28 +1,33 @@
//go:build sqlite3_nolock && unix && !(linux || darwin || freebsd || illumos)
//go:build !(linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos) || sqlite3_nosys
package vfs
import (
"os"
"time"
)
import "os"
func osUnlock(file *os.File, start, len int64) _ErrorCode {
return _OK
func osGetSharedLock(file *os.File) _ErrorCode {
return _IOERR_RDLOCK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
return _OK
func osGetReservedLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return _OK
func osGetPendingLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return _OK
func osGetExclusiveLock(file *os.File) _ErrorCode {
return _IOERR_LOCK
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
return false, _OK
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
return _IOERR_RDLOCK
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
return _IOERR_UNLOCK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
return false, _IOERR_CHECKRESERVEDLOCK
}

View File

@@ -1,4 +1,4 @@
//go:build (linux || illumos) && !sqlite3_flock
//go:build (linux || illumos) && !sqlite3_nosys
package vfs

View File

@@ -1,4 +1,4 @@
//go:build !unix
//go:build !unix || sqlite3_nosys
package vfs

View File

@@ -1,4 +1,4 @@
//go:build !linux && (!darwin || sqlite3_flock)
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
package vfs

View File

@@ -1,4 +1,4 @@
//go:build !unix
//go:build !unix || sqlite3_nosys
package vfs

View File

@@ -1,4 +1,4 @@
//go:build !windows
//go:build !windows || sqlite3_nosys
package vfs

View File

@@ -1,4 +1,4 @@
//go:build !linux && (!darwin || sqlite3_flock)
//go:build !(linux || darwin) || sqlite3_flock || sqlite3_nosys
package vfs

View File

@@ -1,11 +1,10 @@
//go:build unix
//go:build unix && !sqlite3_nosys
package vfs
import (
"os"
"syscall"
"time"
"golang.org/x/sys/unix"
)
@@ -32,60 +31,3 @@ func osSetMode(file *os.File, modeof string) error {
}
return nil
}
func osGetSharedLock(file *os.File) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

82
vfs/os_unix2.go Normal file
View File

@@ -0,0 +1,82 @@
//go:build (linux || darwin || freebsd || openbsd || netbsd || dragonfly || illumos) && !sqlite3_nosys
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osGetSharedLock(file *os.File) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
}
return def
}

View File

@@ -1,3 +1,5 @@
//go:build !sqlite3_nosys
package vfs
import (
@@ -39,6 +41,16 @@ func osGetSharedLock(file *os.File) _ErrorCode {
return rc
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
@@ -90,6 +102,11 @@ func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
return _OK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})

View File

@@ -2,6 +2,7 @@ package mptest
import (
"bytes"
"compress/bzip2"
"context"
"crypto/rand"
"embed"
@@ -24,8 +25,8 @@ import (
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
)
//go:embed testdata/mptest.wasm
var binary []byte
//go:embed testdata/mptest.wasm.bz2
var compressed string
//go:embed testdata/*.*test
var scripts embed.FS
@@ -48,6 +49,11 @@ func TestMain(m *testing.M) {
panic(err)
}
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)

View File

@@ -1,2 +1,2 @@
mptest.wasm filter=lfs diff=lfs merge=lfs -text
mptest.wasm.bz2 filter=lfs diff=lfs merge=lfs -text
*.*test -crlf

View File

@@ -28,4 +28,5 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv mptest.tmp mptest.wasm
mv mptest.tmp mptest.wasm
bzip2 -9f mptest.wasm

View File

@@ -1,5 +1,4 @@
#include <stdbool.h>
#include <stddef.h>
#include <unistd.h>
// Amalgamation
#include "sqlite3.c"
@@ -8,9 +7,8 @@
__attribute__((constructor)) void init() { sqlite3_initialize(); }
static int dont_unlink(const char *pathname) { return 0; }
#define sqlite3_enable_load_extension(...)
#define sqlite3_trace(...)
#define unlink dont_unlink
#define unlink(...) (0)
#undef UNUSED_PARAMETER
#include "mptest.c"

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b77e9e13a487e976a6e71bc698542098433d1cc586ad8f24784f1f325ffb8dd
size 1459145

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74294bf19d213056ef5ffb7a980c3a7de5d029d0621ded53394d3055dfc4f604
size 513604

View File

@@ -2,6 +2,7 @@ package speedtest1
import (
"bytes"
"compress/bzip2"
"context"
"crypto/rand"
"flag"
@@ -23,8 +24,8 @@ import (
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/speedtest1.wasm
var binary []byte
//go:embed testdata/speedtest1.wasm.bz2
var compressed string
var (
rt wazero.Runtime
@@ -45,6 +46,11 @@ func TestMain(m *testing.M) {
panic(err)
}
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)

View File

@@ -1 +1 @@
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text
speedtest1.wasm.bz2 filter=lfs diff=lfs merge=lfs -text

View File

@@ -23,4 +23,5 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv speedtest1.tmp speedtest1.wasm
mv speedtest1.tmp speedtest1.wasm
bzip2 -9f speedtest1.wasm

View File

@@ -6,5 +6,5 @@
// VFS
#include "vfs.c"
#define randomFunc(args...) randomFunc2(args)
#define randomFunc randomFunc2
#include "speedtest1.c"

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3b52de3306965ac3f812592be29697d75232802a13bb16a34344f8d81dbf0637
size 1499410

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:966754393264cc43eb931ece22941d0d607e7e776e26c26b548209d2264d01a1
size 527530

View File

@@ -7,6 +7,7 @@ import (
"io/fs"
"math"
"os"
"os/user"
"path/filepath"
"syscall"
"testing"
@@ -208,6 +209,10 @@ func Test_vfsAccess(t *testing.T) {
t.Error("can't access file")
}
if usr, err := user.Current(); err == nil && usr.Uid == "0" {
t.Skip("skipping as root")
}
util.WriteString(mod, 8, file)
rc = vfsAccess(ctx, mod, 0, 8, ACCESS_READWRITE, 4)
if rc != _OK {