mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 22:19:14 +00:00
Compare commits
75 Commits
gormlite/v
...
gormlite/v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f6d77f3cf4 | ||
|
|
d5d7cd1f2d | ||
|
|
a33a187d48 | ||
|
|
70c6ee15c6 | ||
|
|
994d9b1812 | ||
|
|
b19bd28ed3 | ||
|
|
e66bd51845 | ||
|
|
f5614bc2ed | ||
|
|
d9fcf60b7d | ||
|
|
ac6dd1aa5f | ||
|
|
b1495bd6cb | ||
|
|
2d91760295 | ||
|
|
38d4254bc4 | ||
|
|
c0aa734786 | ||
|
|
fa845dbd3d | ||
|
|
fed315ab79 | ||
|
|
726d7316f7 | ||
|
|
ddb387b021 | ||
|
|
d0f19507f5 | ||
|
|
9d997552ad | ||
|
|
9d75c39dcc | ||
|
|
746a84965e | ||
|
|
312d3b58f2 | ||
|
|
b71cd295c2 | ||
|
|
5b3b61a304 | ||
|
|
d661d15723 | ||
|
|
1e38165ad0 | ||
|
|
58a32d7c9d | ||
|
|
6765e883c1 | ||
|
|
18fc608433 | ||
|
|
77f37893b9 | ||
|
|
f1e36e2581 | ||
|
|
772b9153c7 | ||
|
|
4b280a3a7e | ||
|
|
19b6098bf6 | ||
|
|
2aa685320f | ||
|
|
9941be05c2 | ||
|
|
a0a9ab7737 | ||
|
|
a77727a1ce | ||
|
|
47fe032078 | ||
|
|
bdfe279444 | ||
|
|
a86937a54e | ||
|
|
6ef422fbde | ||
|
|
ff0cb6fb88 | ||
|
|
72db90efdf | ||
|
|
5a3fdef3c5 | ||
|
|
ff34b0cae1 | ||
|
|
f064492bb1 | ||
|
|
1427d30541 | ||
|
|
d3730341f0 | ||
|
|
78ac2386f6 | ||
|
|
632ea933b3 | ||
|
|
0f7fa6ebc9 | ||
|
|
6f7f776488 | ||
|
|
f6d7c5e9c5 | ||
|
|
1cc7ecfe8d | ||
|
|
3844e81404 | ||
|
|
fec1f8d32a | ||
|
|
31572e6095 | ||
|
|
4aee38b957 | ||
|
|
232a7705b5 | ||
|
|
a6c2fccd74 | ||
|
|
6a982559cd | ||
|
|
c7904d30de | ||
|
|
ce4386604d | ||
|
|
26b62c520d | ||
|
|
738714bf32 | ||
|
|
41b020bafc | ||
|
|
d0e720272b | ||
|
|
76171da12b | ||
|
|
dcc845d684 | ||
|
|
f1b42c26d5 | ||
|
|
1e94407ae7 | ||
|
|
eb8d9b95fd | ||
|
|
04037a75ed |
11
.github/workflows/go.yml
vendored
11
.github/workflows/go.yml
vendored
@@ -47,8 +47,12 @@ jobs:
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
|
||||
- name: Test no locks
|
||||
run: go test -v -tags sqlite3_nolock .
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
- name: Test BSD locks
|
||||
run: go test -v -tags sqlite3_bsd ./...
|
||||
run: go test -v -tags sqlite3_flock ./...
|
||||
if: matrix.os == 'macos-latest'
|
||||
|
||||
- name: Coverage report
|
||||
@@ -56,7 +60,8 @@ jobs:
|
||||
with:
|
||||
chart: 'true'
|
||||
amend: 'true'
|
||||
reuse-go: 'true'
|
||||
if: |
|
||||
matrix.os == 'ubuntu-latest' &&
|
||||
github.event_name == 'push'
|
||||
github.event_name == 'push' &&
|
||||
matrix.os == 'ubuntu-latest'
|
||||
continue-on-error: true
|
||||
|
||||
64
README.md
64
README.md
@@ -7,18 +7,26 @@
|
||||
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
|
||||
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
|
||||
|
||||
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
|
||||
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
|
||||
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
|
||||
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
|
||||
implements an in-memory VFS.
|
||||
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
|
||||
implements a VFS for immutable databases.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
|
||||
registers Unicode aware functions.
|
||||
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
|
||||
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
|
||||
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
|
||||
### Caveats
|
||||
|
||||
@@ -31,35 +39,41 @@ This has benefits, but also comes with some drawbacks.
|
||||
Because WASM does not support shared memory,
|
||||
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
|
||||
|
||||
To work around this limitation, SQLite is compiled with
|
||||
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
|
||||
making `EXCLUSIVE` the default locking mode.
|
||||
For non-WAL databases, `NORMAL` locking mode can be activated with
|
||||
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
|
||||
To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
|
||||
to always use `EXCLUSIVE` locking mode for WAL databases.
|
||||
|
||||
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
|
||||
the `database/sql` driver defaults to `NORMAL` locking mode.
|
||||
To open WAL databases, or use `EXCLUSIVE` locking mode,
|
||||
disable connection pooling by calling
|
||||
to use the [`database/sql`](https://pkg.go.dev/database/sql)
|
||||
driver with WAL mode databases you should disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### POSIX Advisory Locks
|
||||
|
||||
POSIX advisory locks, which SQLite uses, are
|
||||
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
|
||||
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
|
||||
|
||||
On Linux, macOS and illumos, this module uses
|
||||
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
|
||||
to synchronize access to database files.
|
||||
OFD locks are fully compatible with process-associated POSIX advisory locks.
|
||||
|
||||
On BSD Unixes, this module uses
|
||||
On BSD Unixes, this module may use
|
||||
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
|
||||
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
|
||||
|
||||
##### TL;DR
|
||||
|
||||
In all platforms for which this package builds,
|
||||
it should be safe to use it to access databases concurrently,
|
||||
from multiple goroutines, processes, and
|
||||
with _other_ implementations of SQLite.
|
||||
|
||||
If the package does not build for your platform,
|
||||
see [this](vfs/README.md#portability).
|
||||
|
||||
#### Testing
|
||||
|
||||
The pure Go VFS is tested by running an unmodified build of SQLite's
|
||||
The pure Go VFS is tested by running SQLite's
|
||||
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
|
||||
on Linux, macOS and Windows.
|
||||
Performance is tested by running
|
||||
@@ -68,6 +82,7 @@ Performance is tested by running
|
||||
### Roadmap
|
||||
|
||||
- [ ] advanced SQLite features
|
||||
- [x] custom functions
|
||||
- [x] nested transactions
|
||||
- [x] incremental BLOB I/O
|
||||
- [x] online backup
|
||||
@@ -77,11 +92,10 @@ Performance is tested by running
|
||||
- [x] in-memory VFS
|
||||
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
|
||||
- [ ] custom SQL functions
|
||||
|
||||
### Alternatives
|
||||
|
||||
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
|
||||
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
|
||||
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
|
||||
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
|
||||
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
|
||||
|
||||
@@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
|
||||
if r == 0 {
|
||||
defer c.closeDB(other)
|
||||
r = c.call(c.api.errcode, uint64(dst))
|
||||
return nil, c.module.error(r, dst)
|
||||
return nil, c.sqlite.error(r, dst)
|
||||
}
|
||||
|
||||
return &Backup{
|
||||
|
||||
42
conn.go
42
conn.go
@@ -2,7 +2,6 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/sqlite3.html
|
||||
type Conn struct {
|
||||
*module
|
||||
*sqlite
|
||||
|
||||
interrupt context.Context
|
||||
waiter chan struct{}
|
||||
@@ -39,7 +38,7 @@ func Open(filename string) (*Conn, error) {
|
||||
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
|
||||
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
|
||||
//
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/open.html
|
||||
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
@@ -50,19 +49,19 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
}
|
||||
|
||||
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
mod, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if conn == nil {
|
||||
mod.close()
|
||||
sqlite.close()
|
||||
} else {
|
||||
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
|
||||
}
|
||||
}()
|
||||
|
||||
c := &Conn{module: mod}
|
||||
c := &Conn{sqlite: sqlite}
|
||||
c.arena = c.newArena(1024)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err != nil {
|
||||
@@ -80,7 +79,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
|
||||
handle := util.ReadUint32(c.mod, connPtr)
|
||||
if err := c.module.error(r, handle); err != nil {
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
}
|
||||
@@ -99,7 +98,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
c.arena.reset()
|
||||
pragmaPtr := c.arena.string(pragmas.String())
|
||||
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
|
||||
if err := c.module.error(r, handle, pragmas.String()); err != nil {
|
||||
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
|
||||
if errors.Is(err, ERROR) {
|
||||
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
|
||||
}
|
||||
@@ -113,7 +112,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
|
||||
func (c *Conn) closeDB(handle uint32) {
|
||||
r := c.call(c.api.closeZombie, uint64(handle))
|
||||
if err := c.module.error(r, handle); err != nil {
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -143,7 +142,7 @@ func (c *Conn) Close() error {
|
||||
|
||||
c.handle = 0
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return c.module.close()
|
||||
return c.close()
|
||||
}
|
||||
|
||||
// Exec is a convenience function that allows an application to run
|
||||
@@ -240,6 +239,11 @@ func (c *Conn) Changes() int64 {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
// Is it the same context?
|
||||
if ctx == c.interrupt {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
@@ -278,7 +282,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
break
|
||||
|
||||
case <-ctx.Done(): // Done was closed.
|
||||
const isInterruptedOffset = 280
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
// Wait for the next call to SetInterrupt.
|
||||
@@ -295,7 +299,7 @@ func (c *Conn) checkInterrupt() bool {
|
||||
if c.interrupt == nil || c.interrupt.Err() == nil {
|
||||
return false
|
||||
}
|
||||
const isInterruptedOffset = 280
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
return true
|
||||
@@ -319,7 +323,7 @@ func (c *Conn) Pragma(str string) ([]string, error) {
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
return c.module.error(rc, c.handle, sql...)
|
||||
return c.sqlite.error(rc, c.handle, sql...)
|
||||
}
|
||||
|
||||
// DriverConn is implemented by the SQLite [database/sql] driver connection.
|
||||
@@ -331,15 +335,5 @@ func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
// [online backup]: https://www.sqlite.org/backup.html
|
||||
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
|
||||
type DriverConn interface {
|
||||
driver.Conn
|
||||
driver.ConnBeginTx
|
||||
driver.ExecerContext
|
||||
driver.ConnPrepareContext
|
||||
|
||||
SetInterrupt(ctx context.Context) (old context.Context)
|
||||
|
||||
Savepoint() Savepoint
|
||||
Backup(srcDB, dstURI string) error
|
||||
Restore(dstDB, srcURI string) error
|
||||
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
|
||||
Raw() *Conn
|
||||
}
|
||||
|
||||
20
const.go
20
const.go
@@ -167,6 +167,18 @@ const (
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
)
|
||||
|
||||
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_deterministic.html
|
||||
type FunctionFlag uint32
|
||||
|
||||
const (
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
SUBTYPE FunctionFlag = 0x000100000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
)
|
||||
|
||||
// Datatype is a fundamental datatype of SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_blob.html
|
||||
@@ -182,18 +194,18 @@ const (
|
||||
|
||||
// String implements the [fmt.Stringer] interface.
|
||||
func (t Datatype) String() string {
|
||||
const name = "INTEGERFLOATTEXTBLOBNULL"
|
||||
const name = "INTEGERFLOATEXTBLOBNULL"
|
||||
switch t {
|
||||
case INTEGER:
|
||||
return name[0:7]
|
||||
case FLOAT:
|
||||
return name[7:12]
|
||||
case TEXT:
|
||||
return name[12:16]
|
||||
return name[11:15]
|
||||
case BLOB:
|
||||
return name[16:20]
|
||||
return name[15:19]
|
||||
case NULL:
|
||||
return name[20:24]
|
||||
return name[19:23]
|
||||
}
|
||||
return strconv.FormatUint(uint64(t), 10)
|
||||
}
|
||||
|
||||
174
context.go
Normal file
174
context.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Context is the context in which an SQL function executes.
|
||||
// An SQLite [Context] is in no way related to a Go [context.Context].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// SetAuxData saves metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(c.ctx, data)
|
||||
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
|
||||
}
|
||||
|
||||
// GetAuxData returns metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) GetAuxData(n int) any {
|
||||
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
|
||||
return util.GetHandle(c.ctx, ptr)
|
||||
}
|
||||
|
||||
// ResultBool sets the result of the function to a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBool(value bool) {
|
||||
var i int64
|
||||
if value {
|
||||
i = 1
|
||||
}
|
||||
c.ResultInt64(i)
|
||||
}
|
||||
|
||||
// ResultInt sets the result of the function to an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt(value int) {
|
||||
c.ResultInt64(int64(value))
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of the function to an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt64(value int64) {
|
||||
c.call(c.api.resultInteger,
|
||||
uint64(c.handle), uint64(value))
|
||||
}
|
||||
|
||||
// ResultFloat sets the result of the function to a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultFloat(value float64) {
|
||||
c.call(c.api.resultFloat,
|
||||
uint64(c.handle), math.Float64bits(value))
|
||||
}
|
||||
|
||||
// ResultText sets the result of the function to a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultText(value string) {
|
||||
ptr := c.newString(value)
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBlob(value []byte) {
|
||||
ptr := c.newBytes(value)
|
||||
c.call(c.api.resultBlob,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor))
|
||||
}
|
||||
|
||||
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultZeroBlob(n int64) {
|
||||
c.call(c.api.resultZeroBlob,
|
||||
uint64(c.handle), uint64(n))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of the function to NULL.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultNull() {
|
||||
c.call(c.api.resultNull,
|
||||
uint64(c.handle))
|
||||
}
|
||||
|
||||
// ResultTime sets the result of the function to a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
if format == TimeFormatDefault {
|
||||
c.resultRFC3339Nano(value)
|
||||
return
|
||||
}
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
c.ResultText(v)
|
||||
case int64:
|
||||
c.ResultInt64(v)
|
||||
case float64:
|
||||
c.ResultFloat(v)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (c Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
|
||||
ptr := c.new(maxlen)
|
||||
buf := util.View(c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
c.call(c.api.resultErrorMem, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
c.call(c.api.resultErrorBig, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
str := err.Error()
|
||||
ptr := c.newString(str)
|
||||
c.call(c.api.resultError,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(str)))
|
||||
c.free(ptr)
|
||||
|
||||
var code uint64
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
switch {
|
||||
case errors.As(err, &xcode):
|
||||
code = uint64(xcode)
|
||||
case errors.As(err, &ecode):
|
||||
code = uint64(ecode)
|
||||
}
|
||||
if code != 0 {
|
||||
c.call(c.api.resultErrorCode,
|
||||
uint64(c.handle), code)
|
||||
}
|
||||
}
|
||||
147
driver/driver.go
147
driver/driver.go
@@ -14,10 +14,9 @@
|
||||
//
|
||||
// [PRAGMA] statements can be specified using "_pragma":
|
||||
//
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// If no PRAGMAs are specified, a busy timeout of 1 minute
|
||||
// and normal locking mode are used.
|
||||
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
|
||||
//
|
||||
// Order matters:
|
||||
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
|
||||
@@ -41,45 +40,96 @@ import (
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// This variable can be replaced with -ldflags:
|
||||
//
|
||||
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
|
||||
var driverName = "sqlite3"
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", sqlite{})
|
||||
if driverName != "" {
|
||||
sql.Register(driverName, sqlite{})
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
|
||||
//
|
||||
// The init function is called by the driver on new connections.
|
||||
// The conn can be used to execute queries, register functions, etc.
|
||||
// Any error return closes the conn and passes the error to database/sql.
|
||||
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
|
||||
c, err := newConnector(dataSourceName, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.OpenDB(c), nil
|
||||
}
|
||||
|
||||
type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.Conn, err = sqlite3.Open(name)
|
||||
func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
c, err := newConnector(name, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Connect(context.Background())
|
||||
}
|
||||
|
||||
c.txBegin = "BEGIN"
|
||||
var pragmas []string
|
||||
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
|
||||
return newConnector(name, nil)
|
||||
}
|
||||
|
||||
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
|
||||
c := connector{name: name, init: init}
|
||||
if strings.HasPrefix(name, "file:") {
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, _ := url.ParseQuery(after)
|
||||
|
||||
switch s := query.Get("_txlock"); s {
|
||||
case "":
|
||||
c.txBegin = "BEGIN"
|
||||
case "deferred", "immediate", "exclusive":
|
||||
c.txBegin = "BEGIN " + s
|
||||
default:
|
||||
c.Close()
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
|
||||
query, err := url.ParseQuery(after)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pragmas = query["_pragma"]
|
||||
c.txlock = query.Get("_txlock")
|
||||
c.pragmas = len(query["_pragma"]) > 0
|
||||
}
|
||||
}
|
||||
if len(pragmas) == 0 {
|
||||
err := c.Conn.Exec(`
|
||||
PRAGMA busy_timeout=60000;
|
||||
PRAGMA locking_mode=normal;
|
||||
`)
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type connector struct {
|
||||
init func(ctx context.Context, conn *sqlite3.Conn) error
|
||||
name string
|
||||
txlock string
|
||||
pragmas bool
|
||||
}
|
||||
|
||||
func (n *connector) Driver() driver.Driver {
|
||||
return sqlite{}
|
||||
}
|
||||
|
||||
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.Conn, err = sqlite3.Open(n.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
switch n.txlock {
|
||||
case "":
|
||||
c.txBegin = "BEGIN"
|
||||
case "deferred", "immediate", "exclusive":
|
||||
c.txBegin = "BEGIN " + n.txlock
|
||||
default:
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
|
||||
}
|
||||
if !n.pragmas {
|
||||
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.reusable = true
|
||||
@@ -90,7 +140,6 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
PRAGMA_query_only;
|
||||
`)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
if s.Step() {
|
||||
@@ -99,7 +148,12 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
}
|
||||
err = s.Close()
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if n.init != nil {
|
||||
err = n.init(ctx, c.Conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -117,12 +171,17 @@ type conn struct {
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
_ driver.ConnPrepareContext = &conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
func (c *conn) Raw() *sqlite3.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
}
|
||||
@@ -167,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
|
||||
func (c *conn) Commit() error {
|
||||
err := c.Conn.Exec(c.txCommit)
|
||||
if err != nil && !c.GetAutocommit() {
|
||||
if err != nil && !c.Conn.GetAutocommit() {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
@@ -259,13 +318,14 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
}
|
||||
|
||||
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
// Use QueryContext to setup bindings.
|
||||
// No need to close rows: that simply resets the statement, exec does the same.
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
err := s.setupBindings(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
old := s.Conn.SetInterrupt(ctx)
|
||||
defer s.Conn.SetInterrupt(old)
|
||||
|
||||
err = s.Stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -275,10 +335,18 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
}
|
||||
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.Stmt.ClearBindings()
|
||||
err := s.setupBindings(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
}
|
||||
|
||||
func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
err := s.Stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ids [3]int
|
||||
for _, arg := range args {
|
||||
@@ -318,11 +386,10 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
|
||||
@@ -47,7 +47,7 @@ func ExampleDriverConn() {
|
||||
}
|
||||
|
||||
err = conn.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(sqlite3.DriverConn)
|
||||
conn := driverConn.(sqlite3.DriverConn).Raw()
|
||||
savept := conn.Savepoint()
|
||||
defer savept.Release(&err)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Embeddable WASM build of SQLite
|
||||
|
||||
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
|
||||
This folder includes an embeddable WASM build of SQLite 3.43.2 for use with
|
||||
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
|
||||
|
||||
The following optional features are compiled in:
|
||||
@@ -9,6 +9,7 @@ The following optional features are compiled in:
|
||||
- [JSON](https://www.sqlite.org/json1.html)
|
||||
- [R*Tree](https://www.sqlite.org/rtree.html)
|
||||
- [GeoPoly](https://www.sqlite.org/geopoly.html)
|
||||
- [soundex](https://www.sqlite.org/lang_corefunc.html#soundex)
|
||||
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
|
||||
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
|
||||
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
|
||||
|
||||
@@ -4,24 +4,27 @@ set -euo pipefail
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
ROOT=../
|
||||
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
|
||||
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
|
||||
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-mexec-model=reactor \
|
||||
-mmutable-globals \
|
||||
-msimd128 -mmutable-globals \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-Wl,--initial-memory=327680 \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
||||
|
||||
trap 'rm -f sqlite3.tmp' EXIT
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-multivalue --enable-mutable-globals \
|
||||
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
|
||||
sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
@@ -33,10 +33,10 @@ sqlite3_column_blob
|
||||
sqlite3_column_bytes
|
||||
sqlite3_blob_open
|
||||
sqlite3_blob_close
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_blob_bytes
|
||||
sqlite3_blob_read
|
||||
sqlite3_blob_write
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_backup_init
|
||||
sqlite3_backup_step
|
||||
sqlite3_backup_finish
|
||||
@@ -46,4 +46,29 @@ sqlite3_uri_parameter
|
||||
sqlite3_uri_key
|
||||
sqlite3_changes64
|
||||
sqlite3_last_insert_rowid
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_anycollseq_init
|
||||
sqlite3_create_collation_go
|
||||
sqlite3_create_function_go
|
||||
sqlite3_create_aggregate_function_go
|
||||
sqlite3_create_window_function_go
|
||||
sqlite3_aggregate_context
|
||||
sqlite3_user_data
|
||||
sqlite3_set_auxdata_go
|
||||
sqlite3_get_auxdata
|
||||
sqlite3_value_type
|
||||
sqlite3_value_int64
|
||||
sqlite3_value_double
|
||||
sqlite3_value_text
|
||||
sqlite3_value_blob
|
||||
sqlite3_value_bytes
|
||||
sqlite3_result_null
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_double
|
||||
sqlite3_result_text64
|
||||
sqlite3_result_blob64
|
||||
sqlite3_result_zeroblob64
|
||||
sqlite3_result_error
|
||||
sqlite3_result_error_code
|
||||
sqlite3_result_error_nomem
|
||||
sqlite3_result_error_toobig
|
||||
Binary file not shown.
22
error.go
22
error.go
@@ -68,6 +68,19 @@ func (e *Error) Is(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
|
||||
func (e *Error) As(err any) bool {
|
||||
switch c := err.(type) {
|
||||
case *ErrorCode:
|
||||
*c = e.Code()
|
||||
return true
|
||||
case *ExtendedErrorCode:
|
||||
*c = e.ExtendedCode()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.Code() == BUSY
|
||||
@@ -104,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool {
|
||||
return ok && c == ErrorCode(e)
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode].
|
||||
func (e ExtendedErrorCode) As(err any) bool {
|
||||
c, ok := err.(*ErrorCode)
|
||||
if ok {
|
||||
*c = ErrorCode(e)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ExtendedErrorCode) Temporary() bool {
|
||||
return ErrorCode(e) == BUSY
|
||||
|
||||
@@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) {
|
||||
func TestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := Error{code: 0x8080}
|
||||
if rc := err.Code(); rc != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", rc)
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
err := &Error{code: 0x8080}
|
||||
if !errors.As(err, &err) {
|
||||
t.Fatal("want true")
|
||||
}
|
||||
if !errors.Is(&err, ErrorCode(0x80)) {
|
||||
if ecode := err.Code(); ecode != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err, ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if rc := err.ExtendedCode(); rc != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", rc)
|
||||
if xcode := err.ExtendedCode(); xcode != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
|
||||
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(err, xErrorCode(0x8080)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if s := err.Error(); s != "sqlite3: 32896" {
|
||||
t.Errorf("got %q", s)
|
||||
}
|
||||
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
|
||||
109
ext/stats/stats.go
Normal file
109
ext/stats/stats.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Package stats provides aggregate functions for statistics.
|
||||
//
|
||||
// Functions:
|
||||
// - stddev_pop: population standard deviation
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - var_pop: population variance
|
||||
// - var_samp: sample variance
|
||||
// - covar_pop: population covariance
|
||||
// - covar_samp: sample covariance
|
||||
// - corr: correlation coefficient
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
//
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
|
||||
}
|
||||
|
||||
const (
|
||||
var_pop = iota
|
||||
var_samp
|
||||
stddev_pop
|
||||
stddev_samp
|
||||
corr
|
||||
)
|
||||
|
||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
||||
}
|
||||
|
||||
type variance struct {
|
||||
kind int
|
||||
welford
|
||||
}
|
||||
|
||||
func (fn *variance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.var_pop()
|
||||
case var_samp:
|
||||
r = fn.var_samp()
|
||||
case stddev_pop:
|
||||
r = fn.stddev_pop()
|
||||
case stddev_samp:
|
||||
r = fn.stddev_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func newCovariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
||||
}
|
||||
|
||||
type covariance struct {
|
||||
kind int
|
||||
welford2
|
||||
}
|
||||
|
||||
func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.covar_pop()
|
||||
case var_samp:
|
||||
r = fn.covar_samp()
|
||||
case corr:
|
||||
r = fn.correlation()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
140
ext/stats/stats_test.go
Normal file
140
ext/stats/stats_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister_variance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT
|
||||
sum(x), avg(x),
|
||||
var_samp(x), var_pop(x),
|
||||
stddev_samp(x), stddev_pop(x)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 40 {
|
||||
t.Errorf("got %v, want 40", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(3); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 4.5, 18, 0, 0}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_covariance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 10, 30, 75, 22.5}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
109
ext/stats/welford.go
Normal file
109
ext/stats/welford.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package stats
|
||||
|
||||
import "math"
|
||||
|
||||
// Welford's algorithm with Kahan summation:
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
type welford struct {
|
||||
m1, m2 kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford) average() float64 {
|
||||
return w.m1.hi
|
||||
}
|
||||
|
||||
func (w welford) var_pop() float64 {
|
||||
return w.m2.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford) var_samp() float64 {
|
||||
return w.m2.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford) stddev_pop() float64 {
|
||||
return math.Sqrt(w.var_pop())
|
||||
}
|
||||
|
||||
func (w welford) stddev_samp() float64 {
|
||||
return math.Sqrt(w.var_samp())
|
||||
}
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
|
||||
type welford2 struct {
|
||||
m1x, m2x kahan
|
||||
m1y, m2y kahan
|
||||
cov kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford2) covar_pop() float64 {
|
||||
return w.cov.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford2) covar_samp() float64 {
|
||||
return w.cov.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford2) correlation() float64 {
|
||||
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(x, y float64) {
|
||||
w.n++
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.add(d1x / float64(w.n))
|
||||
w.m1y.add(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.add(d1x * d2x)
|
||||
w.m2y.add(d1y * d2y)
|
||||
w.cov.add(d1x * d2y)
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(x, y float64) {
|
||||
w.n--
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.sub(d1x / float64(w.n))
|
||||
w.m1y.sub(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.sub(d1x * d2x)
|
||||
w.m2y.sub(d1y * d2y)
|
||||
w.cov.sub(d1x * d2y)
|
||||
}
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
y := k.lo + x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
func (k *kahan) sub(x float64) {
|
||||
y := k.lo - x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
75
ext/stats/welford_test.go
Normal file
75
ext/stats/welford_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_welford(t *testing.T) {
|
||||
var s1, s2 welford
|
||||
|
||||
s1.enqueue(4)
|
||||
s1.enqueue(7)
|
||||
s1.enqueue(13)
|
||||
s1.enqueue(16)
|
||||
if got := s1.average(); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := s1.var_samp(); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := s1.var_pop(); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := s1.stddev_samp(); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
|
||||
s1.dequeue(4)
|
||||
s2.enqueue(7)
|
||||
s2.enqueue(13)
|
||||
s2.enqueue(16)
|
||||
if s1.var_pop() != s2.var_pop() {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_covar(t *testing.T) {
|
||||
var c1, c2 welford2
|
||||
|
||||
c1.enqueue(3, 70)
|
||||
c1.enqueue(5, 80)
|
||||
c1.enqueue(2, 60)
|
||||
c1.enqueue(7, 90)
|
||||
c1.enqueue(4, 75)
|
||||
|
||||
if got := c1.covar_samp(); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := c1.covar_pop(); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
|
||||
c1.dequeue(3, 70)
|
||||
c2.enqueue(5, 80)
|
||||
c2.enqueue(2, 60)
|
||||
c2.enqueue(7, 90)
|
||||
c2.enqueue(4, 75)
|
||||
if c1.covar_pop() != c2.covar_pop() {
|
||||
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_correlation(t *testing.T) {
|
||||
var c welford2
|
||||
c.enqueue(1, 3)
|
||||
c.enqueue(2, 2)
|
||||
c.enqueue(3, 1)
|
||||
|
||||
if got := c.correlation(); got != -1 {
|
||||
t.Errorf("got %v, want -1", got)
|
||||
}
|
||||
}
|
||||
181
ext/unicode/unicode.go
Normal file
181
ext/unicode/unicode.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Package unicode provides an alternative to the SQLite ICU extension.
|
||||
//
|
||||
// Like the [ICU extension], it provides Unicode aware:
|
||||
// - upper() and lower() functions,
|
||||
// - LIKE and REGEXP operators,
|
||||
// - collation sequences.
|
||||
//
|
||||
// The implementation is not 100% compatible with the [ICU extension]:
|
||||
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
|
||||
// - the LIKE operator follows [strings.EqualFold] rules;
|
||||
// - the REGEXP operator uses Go [regex/syntax];
|
||||
// - collation sequences use [collate].
|
||||
//
|
||||
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
|
||||
//
|
||||
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
|
||||
db.CreateFunction("like", 2, flags, like)
|
||||
db.CreateFunction("like", 3, flags, like)
|
||||
db.CreateFunction("upper", 1, flags, upper)
|
||||
db.CreateFunction("upper", 2, flags, upper)
|
||||
db.CreateFunction("lower", 1, flags, lower)
|
||||
db.CreateFunction("lower", 2, flags, lower)
|
||||
db.CreateFunction("regexp", 2, flags, regex)
|
||||
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
name := arg[1].Text()
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := RegisterCollation(db, arg[0].Text(), name)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterCollation registers a Unicode collation sequence for a database connection.
|
||||
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
|
||||
tag, err := language.Parse(locale)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.CreateCollation(name, collate.New(tag).Compare)
|
||||
}
|
||||
|
||||
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) == 1 {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
return
|
||||
}
|
||||
cs, ok := ctx.GetAuxData(1).(cases.Caser)
|
||||
if !ok {
|
||||
t, err := language.Parse(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
c := cases.Upper(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
}
|
||||
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
|
||||
}
|
||||
|
||||
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) == 1 {
|
||||
ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob()))
|
||||
return
|
||||
}
|
||||
cs, ok := ctx.GetAuxData(1).(cases.Caser)
|
||||
if !ok {
|
||||
t, err := language.Parse(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
c := cases.Lower(t)
|
||||
ctx.SetAuxData(1, c)
|
||||
cs = c
|
||||
}
|
||||
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
|
||||
}
|
||||
|
||||
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
}
|
||||
|
||||
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
escape := rune(-1)
|
||||
if len(arg) == 3 {
|
||||
var size int
|
||||
b := arg[2].RawBlob()
|
||||
escape, size = utf8.DecodeRune(b)
|
||||
if size != len(b) {
|
||||
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type likeData struct {
|
||||
*regexp.Regexp
|
||||
escape rune
|
||||
}
|
||||
|
||||
re, ok := ctx.GetAuxData(0).(likeData)
|
||||
if !ok || re.escape != escape {
|
||||
re = likeData{
|
||||
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
|
||||
escape,
|
||||
}
|
||||
ctx.SetAuxData(0, re)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
}
|
||||
|
||||
func like2regex(pattern string, escape rune) string {
|
||||
var re strings.Builder
|
||||
start := 0
|
||||
literal := false
|
||||
re.Grow(len(pattern) + 10)
|
||||
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
|
||||
for i, r := range pattern {
|
||||
if start < 0 {
|
||||
start = i
|
||||
}
|
||||
if literal {
|
||||
literal = false
|
||||
continue
|
||||
}
|
||||
var symbol string
|
||||
switch r {
|
||||
case '_':
|
||||
symbol = `.`
|
||||
case '%':
|
||||
symbol = `.*`
|
||||
case escape:
|
||||
literal = true
|
||||
default:
|
||||
continue
|
||||
}
|
||||
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
|
||||
re.WriteString(symbol)
|
||||
start = -1
|
||||
}
|
||||
if start >= 0 {
|
||||
re.WriteString(regexp.QuoteMeta(pattern[start:]))
|
||||
}
|
||||
re.WriteString(`\z`)
|
||||
return re.String()
|
||||
}
|
||||
215
ext/unicode/unicode_test.go
Normal file
215
ext/unicode/unicode_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := func(fn string) string {
|
||||
stmt, _, err := db.Prepare(`SELECT ` + fn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
return stmt.ColumnText(0)
|
||||
}
|
||||
t.Fatal(stmt.Err())
|
||||
return ""
|
||||
}
|
||||
|
||||
Register(db)
|
||||
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{`upper('hello')`, "HELLO"},
|
||||
{`lower('HELLO')`, "hello"},
|
||||
{`upper('привет')`, "ПРИВЕТ"},
|
||||
{`lower('ПРИВЕТ')`, "привет"},
|
||||
{`upper('istanbul')`, "ISTANBUL"},
|
||||
{`upper('istanbul', 'tr-TR')`, "İSTANBUL"},
|
||||
{`lower('Dünyanın İlk Borsası', 'tr-TR')`, "dünyanın ilk borsası"},
|
||||
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
|
||||
{`'Hello' REGEXP 'ell'`, "1"},
|
||||
{`'Hello' REGEXP 'el.'`, "1"},
|
||||
{`'Hello' LIKE 'hel_'`, "0"},
|
||||
{`'Hello' LIKE 'hel%'`, "1"},
|
||||
{`'Hello' LIKE 'h_llo'`, "1"},
|
||||
{`'Hello' LIKE 'hello'`, "1"},
|
||||
{`'Привет' LIKE 'ПРИВЕТ'`, "1"},
|
||||
{`'100%' LIKE '100|%' ESCAPE '|'`, "1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.test, func(t *testing.T) {
|
||||
if got := exec(tt.test); got != tt.want {
|
||||
t.Errorf("exec(%q) = %q, want %q", tt.test, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_collation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"}
|
||||
|
||||
for stmt.Step() {
|
||||
got = append(got, stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`SELECT upper('hello', 'enUS')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT lower('hello', 'enUS')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT 'hello' LIKE 'HELLO' ESCAPE '\\'`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('enUS', 'error')`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.ERROR) {
|
||||
t.Errorf("got %v, want sqlite3.ERROR", err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT icu_load_collation('enUS', '')`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_like2regex(t *testing.T) {
|
||||
const prefix = `(?is)\A`
|
||||
const sufix = `\z`
|
||||
tests := []struct {
|
||||
pattern string
|
||||
escape rune
|
||||
want string
|
||||
}{
|
||||
{`a`, -1, `a`},
|
||||
{`a.`, -1, `a\.`},
|
||||
{`a%`, -1, `a.*`},
|
||||
{`a\`, -1, `a\\`},
|
||||
{`a_b`, -1, `a.b`},
|
||||
{`a|b`, '|', `ab`},
|
||||
{`a|_`, '|', `a_`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern, func(t *testing.T) {
|
||||
want := prefix + tt.want + sufix
|
||||
if got := like2regex(tt.pattern, tt.escape); got != want {
|
||||
t.Errorf("like2regex() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
186
func.go
Normal file
186
func.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// AnyCollationNeeded registers a fake collating function
|
||||
// for any unknown collating sequence.
|
||||
// The fake collating function works like BINARY.
|
||||
//
|
||||
// This can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c *Conn) AnyCollationNeeded() {
|
||||
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
|
||||
}
|
||||
|
||||
// CreateCollation defines a new collating sequence.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_collation.html
|
||||
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createCollation,
|
||||
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
|
||||
if err := c.error(r); err != nil {
|
||||
util.DelHandle(c.ctx, funcPtr)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFunction defines a new scalar SQL function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createFunction,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
call := c.api.createAggregate
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
if _, ok := fn().(WindowFunction); ok {
|
||||
call = c.api.createWindow
|
||||
}
|
||||
r := c.call(call,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// AggregateFunction is the interface an aggregate function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/appfunc.html
|
||||
type AggregateFunction interface {
|
||||
// Step is invoked to add a row to the current window.
|
||||
// The function arguments, if any, corresponding to the row being added are passed to Step.
|
||||
Step(ctx Context, arg ...Value)
|
||||
|
||||
// Value is invoked to return the current value of the aggregate.
|
||||
Value(ctx Context)
|
||||
}
|
||||
|
||||
// WindowFunction is the interface an aggregate window function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/windowfunctions.html
|
||||
type WindowFunction interface {
|
||||
AggregateFunction
|
||||
|
||||
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
|
||||
// The function arguments, if any, are those passed to Step for the row being removed.
|
||||
Inverse(ctx Context, arg ...Value)
|
||||
}
|
||||
|
||||
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
||||
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
|
||||
}
|
||||
|
||||
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
var handle uint32
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
if err := util.DelHandle(ctx, handle); err != nil {
|
||||
Context{sqlite, pCtx}.ResultError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
}
|
||||
|
||||
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
|
||||
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
|
||||
return util.GetHandle(sqlite.ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
|
||||
// On close, we're getting rid of the handle.
|
||||
// Don't allocate space to store it.
|
||||
var size uint64
|
||||
if close == nil {
|
||||
size = ptrlen
|
||||
}
|
||||
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
|
||||
|
||||
// Try loading the handle, if we already have one, or want a new one.
|
||||
if ptr != 0 || size != 0 {
|
||||
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
|
||||
fn := util.GetHandle(sqlite.ctx, handle)
|
||||
if close != nil {
|
||||
*close = handle
|
||||
}
|
||||
if fn != nil {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new aggregate and store the handle.
|
||||
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
|
||||
if ptr != 0 {
|
||||
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
|
||||
args := make([]Value, nArg)
|
||||
for i := range args {
|
||||
args[i] = Value{
|
||||
sqlite: sqlite,
|
||||
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
154
func_test.go
Normal file
154
func_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateCollation() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateCollation("french", collate.New(language.French).Compare)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// cote
|
||||
// coté
|
||||
// côte
|
||||
// côté
|
||||
// cotée
|
||||
// coter
|
||||
}
|
||||
|
||||
func ExampleConn_CreateFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// COTE
|
||||
// COTÉ
|
||||
// CÔTE
|
||||
// CÔTÉ
|
||||
// COTÉE
|
||||
// COTER
|
||||
}
|
||||
|
||||
func ExampleContext_SetAuxData() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
ctx.SetAuxData(0, r)
|
||||
re = r
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// cote
|
||||
// côte
|
||||
// cotée
|
||||
}
|
||||
87
func_win_test.go
Normal file
87
func_win_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"unicode"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateWindowFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnInt(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// 1
|
||||
// 2
|
||||
// 2
|
||||
// 1
|
||||
// 0
|
||||
// 0
|
||||
}
|
||||
|
||||
type countASCII struct{ result int }
|
||||
|
||||
func newASCIICounter() sqlite3.AggregateFunction {
|
||||
return &countASCII{}
|
||||
}
|
||||
|
||||
func (f *countASCII) Value(ctx sqlite3.Context) {
|
||||
ctx.ResultInt(f.result)
|
||||
}
|
||||
|
||||
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result++
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result--
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
|
||||
if arg.Type() != sqlite3.TEXT {
|
||||
return false
|
||||
}
|
||||
for _, c := range arg.RawBlob() {
|
||||
if c > unicode.MaxASCII {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
9
go.mod
9
go.mod
@@ -1,13 +1,14 @@
|
||||
module github.com/ncruces/go-sqlite3
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v0.1.5
|
||||
github.com/psanford/httpreadat v0.1.0
|
||||
github.com/tetratelabs/wazero v1.2.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.8.0
|
||||
github.com/tetratelabs/wazero v1.5.0
|
||||
golang.org/x/sync v0.4.0
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.org/x/text v0.13.0
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
|
||||
14
go.sum
14
go.sum
@@ -2,9 +2,11 @@ github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FB
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
|
||||
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
|
||||
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
|
||||
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
|
||||
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
|
||||
5
go.work.sum
Normal file
5
go.work.sum
Normal file
@@ -0,0 +1,5 @@
|
||||
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
|
||||
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
@@ -1,5 +1,7 @@
|
||||
# GORM SQLite Driver
|
||||
|
||||
[](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
@@ -19,6 +21,6 @@ Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
|
||||
```go
|
||||
db, err := gorm.Open(gormlite.Open(
|
||||
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
|
||||
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=foreign_keys(1)"),
|
||||
&gorm.Config{})
|
||||
```
|
||||
@@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) {
|
||||
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
UniqueValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func parseDDL(strs ...string) (*ddl, error) {
|
||||
for _, column := range getAllColumns(matches[1]) {
|
||||
for idx, c := range result.columns {
|
||||
if c.NameValue.String == column {
|
||||
c.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
|
||||
result.columns[idx] = c
|
||||
}
|
||||
}
|
||||
@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (d *ddl) clone() *ddl {
|
||||
copied := new(ddl)
|
||||
*copied = *d
|
||||
|
||||
copied.fields = make([]string, len(d.fields))
|
||||
copy(copied.fields, d.fields)
|
||||
copied.columns = make([]migrator.ColumnType, len(d.columns))
|
||||
copy(copied.columns, d.columns)
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) renameTable(dst, src string) error {
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
|
||||
if replaced == d.head {
|
||||
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
|
||||
}
|
||||
|
||||
d.head = replaced
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
@@ -208,6 +235,17 @@ func (d *ddl) removeConstraint(name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) hasConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for _, f := range d.fields {
|
||||
if reg.MatchString(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
@@ -229,3 +267,30 @@ func (d *ddl) getColumns() []string {
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (d *ddl) alterColumn(name, sql string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *ddl) removeColumn(name string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) {
|
||||
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
|
||||
}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
}},
|
||||
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
@@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
|
||||
{"with_special_characters", []string{
|
||||
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
|
||||
}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -79,6 +79,24 @@ func TestParseDDL(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"non-unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
|
||||
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
@@ -104,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
|
||||
@@ -113,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
NameValue: sql.NullString{String: "dark_mode", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: true, Valid: true},
|
||||
DefaultValueValue: sql.NullString{String: "true", Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
|
||||
11
gormlite/download.sh
Executable file
11
gormlite/download.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"
|
||||
@@ -7,9 +7,15 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
if errors.Is(err, sqlite3.CONSTRAINT_UNIQUE) {
|
||||
func (_Dialector) Translate(err error) error {
|
||||
switch {
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
|
||||
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
|
||||
return gorm.ErrDuplicatedKey
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
|
||||
return gorm.ErrForeignKeyViolated
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package gormlite
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
|
||||
)
|
||||
@@ -1,16 +1,16 @@
|
||||
module github.com/ncruces/go-sqlite3/gormlite
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/ncruces/go-sqlite3 v0.7.2
|
||||
gorm.io/gorm v1.25.1
|
||||
github.com/ncruces/go-sqlite3 v0.9.1
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/ncruces/julianday v0.1.5 // indirect
|
||||
github.com/tetratelabs/wazero v1.2.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.5.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
)
|
||||
|
||||
@@ -2,13 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/ncruces/go-sqlite3 v0.7.2 h1:K7jU4rnUxFdUsbEL+B0Xc+VexLTEwGSO6Qh91Qh4hYc=
|
||||
github.com/ncruces/go-sqlite3 v0.7.2/go.mod h1:t3dP4AP9rJddU+ffFv0h6fWyeOCEhjxrYc1nsYG7aQI=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo=
|
||||
github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08=
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
|
||||
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
|
||||
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
|
||||
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
|
||||
@@ -3,7 +3,6 @@ package gormlite
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -12,11 +11,11 @@ import (
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
type _Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
func (m *_Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
var enabled int
|
||||
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
|
||||
if enabled == 1 {
|
||||
@@ -27,7 +26,7 @@ func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
return fc()
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
func (m _Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
@@ -35,7 +34,7 @@ func (m Migrator) HasTable(value interface{}) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
func (m _Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
@@ -52,11 +51,11 @@ func (m Migrator) DropTable(values ...interface{}) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
func (m _Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
func (m _Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
@@ -76,31 +75,24 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
func (m _Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
|
||||
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
|
||||
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
|
||||
|
||||
if createSQL == rawDDL {
|
||||
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
|
||||
}
|
||||
|
||||
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
}
|
||||
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
@@ -148,29 +140,23 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
func (m _Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, "")
|
||||
|
||||
return createSQL, nil, nil
|
||||
ddl.removeColumn(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
func (m _Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
@@ -185,22 +171,16 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
constraintSql = "CONSTRAINT ? CHECK (?)"
|
||||
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
} else {
|
||||
return "", nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.addConstraint(constraintName, constraintSql)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, constraintValues, nil
|
||||
ddl.addConstraint(constraintName, constraintSql)
|
||||
return ddl, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
func (m _Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
@@ -210,20 +190,14 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
}
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.removeConstraint(name)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, nil, nil
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
ddl.removeConstraint(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
func (m _Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
@@ -244,13 +218,13 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
func (m _Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
@@ -269,7 +243,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
func (m _Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
@@ -298,7 +272,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
func (m _Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
@@ -317,18 +291,21 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
if sql != "" {
|
||||
if err := m.DropIndex(value, oldName); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
|
||||
}
|
||||
return fmt.Errorf("failed to find index with name %v", oldName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
func (m _Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
@@ -362,7 +339,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
func (m _Migrator) getRawDDL(table string) (string, error) {
|
||||
var createSQL string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
|
||||
|
||||
@@ -372,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
return createSQL, nil
|
||||
}
|
||||
|
||||
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
|
||||
func (m _Migrator) recreateTable(
|
||||
value interface{}, tablePtr *string,
|
||||
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
|
||||
) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
table := stmt.Table
|
||||
if tablePtr != nil {
|
||||
@@ -385,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
return err
|
||||
}
|
||||
|
||||
newTableName := table + "__temp"
|
||||
|
||||
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
|
||||
originDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createSQL == "" {
|
||||
|
||||
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createDDL == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
newTableName := table + "__temp"
|
||||
if err := createDDL.renameTable(newTableName, table); err != nil {
|
||||
return err
|
||||
}
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
|
||||
createDDL, err := parseDDL(createSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
columns := createDDL.getColumns()
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return m.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
@@ -14,27 +13,33 @@ import (
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
// Open opens a GORM dialector from a data source name.
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &_Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
// Open opens a GORM dialector from a database handle.
|
||||
func OpenDB(db *sql.DB) gorm.Dialector {
|
||||
return &_Dialector{Conn: db}
|
||||
}
|
||||
|
||||
type _Dialector struct {
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
func (dialector _Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
func (dialector _Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
conn, err := sql.Open("sqlite3", dialector.DSN)
|
||||
conn, err := driver.Open(dialector.DSN, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -49,7 +54,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if compareVersion(version, "3.35.0") >= 0 {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
@@ -65,7 +70,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"INSERT": func(c clause.Clause, builder clause.Builder) {
|
||||
if insert, ok := c.Expression.(clause.Insert); ok {
|
||||
@@ -114,7 +119,7 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
if field.AutoIncrement {
|
||||
return clause.Expr{SQL: "NULL"}
|
||||
}
|
||||
@@ -123,44 +128,77 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return _Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('`')
|
||||
if strings.Contains(str, ".") {
|
||||
for idx, str := range strings.Split(str, ".") {
|
||||
if idx > 0 {
|
||||
writer.WriteString(".`")
|
||||
func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
var (
|
||||
underQuoted, selfQuoted bool
|
||||
continuousBacktick int8
|
||||
shiftDelimiter int8
|
||||
)
|
||||
|
||||
for _, v := range []byte(str) {
|
||||
switch v {
|
||||
case '`':
|
||||
continuousBacktick++
|
||||
if continuousBacktick == 2 {
|
||||
writer.WriteString("``")
|
||||
continuousBacktick = 0
|
||||
}
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
case '.':
|
||||
if continuousBacktick > 0 || !selfQuoted {
|
||||
shiftDelimiter = 0
|
||||
underQuoted = false
|
||||
continuousBacktick = 0
|
||||
writer.WriteString("`")
|
||||
}
|
||||
writer.WriteByte(v)
|
||||
continue
|
||||
default:
|
||||
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
||||
writer.WriteString("`")
|
||||
underQuoted = true
|
||||
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
||||
continuousBacktick -= 1
|
||||
}
|
||||
}
|
||||
|
||||
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
|
||||
writer.WriteByte(v)
|
||||
}
|
||||
} else {
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
shiftDelimiter++
|
||||
}
|
||||
|
||||
if continuousBacktick > 0 && !selfQuoted {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
writer.WriteString("`")
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
func (dialector _Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
func (dialector _Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement && !field.PrimaryKey {
|
||||
if field.AutoIncrement {
|
||||
// doesn't check `PrimaryKey`, to keep backward compatibility
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
@@ -184,12 +222,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
@@ -13,22 +16,52 @@ func TestDialector(t *testing.T) {
|
||||
// This is the DSN of the in-memory SQLite database for these tests.
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
// Custom connection with a custom function called "my_custom_function".
|
||||
db, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
|
||||
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText("my-result")
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector *Dialector
|
||||
dialector gorm.Dialector
|
||||
openSuccess bool
|
||||
query string
|
||||
querySuccess bool
|
||||
}{
|
||||
{
|
||||
description: "Default driver",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
description: "Default driver",
|
||||
dialector: Open(InMemoryDSN),
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom function",
|
||||
dialector: Open(InMemoryDSN),
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: false,
|
||||
},
|
||||
{
|
||||
description: "Custom connection",
|
||||
dialector: OpenDB(db),
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom connection, custom function",
|
||||
dialector: OpenDB(db),
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: true,
|
||||
},
|
||||
}
|
||||
for rowIndex, row := range rows {
|
||||
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
|
||||
|
||||
@@ -4,11 +4,21 @@ set -euo pipefail
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
rm -rf gorm/ tests/
|
||||
git clone --filter=blob:none --branch=v1.25.1 https://github.com/go-gorm/gorm.git
|
||||
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
|
||||
mv gorm/tests tests
|
||||
rm -rf gorm/
|
||||
|
||||
patch -p1 -N < tests.patch
|
||||
|
||||
cd tests
|
||||
go mod tidy && go test
|
||||
go mod edit \
|
||||
-require github.com/ncruces/go-sqlite3/gormlite@v0.0.0 \
|
||||
-replace github.com/ncruces/go-sqlite3/gormlite=../ \
|
||||
-replace github.com/ncruces/go-sqlite3=../../ \
|
||||
-droprequire gorm.io/driver/sqlite \
|
||||
-dropreplace gorm.io/gorm
|
||||
go mod tidy && go work use . && go test
|
||||
|
||||
cd ..
|
||||
rm -rf tests/
|
||||
go work use -r .
|
||||
@@ -1,51 +1,10 @@
|
||||
diff --git a/tests/.gitignore b/tests/.gitignore
|
||||
index 08cb523..72e8ffc 100644
|
||||
--- a/tests/.gitignore
|
||||
+++ b/tests/.gitignore
|
||||
@@ -1 +1 @@
|
||||
-go.sum
|
||||
+*
|
||||
diff --git a/tests/go.mod b/tests/go.mod
|
||||
index f47d175..84b80c2 100644
|
||||
--- a/tests/go.mod
|
||||
+++ b/tests/go.mod
|
||||
@@ -7,13 +7,13 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.8
|
||||
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
||||
+ github.com/ncruces/go-sqlite3 v0.7.2
|
||||
+ github.com/ncruces/go-sqlite3/gormlite v0.0.0
|
||||
golang.org/x/crypto v0.8.0 // indirect
|
||||
gorm.io/driver/mysql v1.5.0
|
||||
gorm.io/driver/postgres v1.5.0
|
||||
- gorm.io/driver/sqlite v1.5.0
|
||||
gorm.io/driver/sqlserver v1.4.3
|
||||
- gorm.io/gorm v1.25.0
|
||||
+ gorm.io/gorm v1.25.1
|
||||
)
|
||||
|
||||
-replace gorm.io/gorm => ../
|
||||
+replace github.com/ncruces/go-sqlite3/gormlite => ../
|
||||
diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go
|
||||
index 1412169..472434b 100644
|
||||
--- a/tests/scanner_valuer_test.go
|
||||
+++ b/tests/scanner_valuer_test.go
|
||||
@@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
|
||||
return errors.New("Too short")
|
||||
}
|
||||
|
||||
- *data = b[3:]
|
||||
+ *data = append((*data)[0:], b[3:]...)
|
||||
return nil
|
||||
} else if s, ok := value.(string); ok {
|
||||
- *data = []byte(s)[3:]
|
||||
+ *data = []byte(s[3:])
|
||||
return nil
|
||||
}
|
||||
|
||||
diff --git a/tests/tests_test.go b/tests/tests_test.go
|
||||
index 90eb847..cd9af43 100644
|
||||
--- a/tests/tests_test.go
|
||||
+++ b/tests/tests_test.go
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
@@ -61,3 +20,12 @@ index 90eb847..cd9af43 100644
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -89,7 +91,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
||||
default:
|
||||
log.Println("testing sqlite3...")
|
||||
- db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
|
||||
+ db, err = gorm.Open(sqlite.Open("file:"+filepath.Join(os.TempDir(), "gorm.db")+"?_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"), cfg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,32 @@ import (
|
||||
type i32 interface{ ~int32 | ~uint32 }
|
||||
type i64 interface{ ~int64 | ~uint64 }
|
||||
|
||||
type funcVI[T0 i32] func(context.Context, api.Module, T0)
|
||||
|
||||
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]))
|
||||
}
|
||||
|
||||
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVI[T0](fn),
|
||||
[]api.ValueType{api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
|
||||
|
||||
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
|
||||
}
|
||||
|
||||
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
|
||||
|
||||
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
|
||||
75
internal/util/handle.go
Normal file
75
internal/util/handle.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
)
|
||||
|
||||
type handleKey struct{}
|
||||
type handleState struct {
|
||||
handles []any
|
||||
empty int
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context) context.Context {
|
||||
state := new(handleState)
|
||||
ctx = experimental.WithCloseNotifier(ctx, state)
|
||||
ctx = context.WithValue(ctx, handleKey{}, state)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
|
||||
for _, h := range s.handles {
|
||||
if c, ok := h.(io.Closer); ok {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
s.handles = nil
|
||||
s.empty = 0
|
||||
}
|
||||
|
||||
func GetHandle(ctx context.Context, id uint32) any {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
return s.handles[^id]
|
||||
}
|
||||
|
||||
func DelHandle(ctx context.Context, id uint32) error {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
a := s.handles[^id]
|
||||
s.handles[^id] = nil
|
||||
s.empty++
|
||||
if c, ok := a.(io.Closer); ok {
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddHandle(ctx context.Context, a any) (id uint32) {
|
||||
if a == nil {
|
||||
panic(NilErr)
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
|
||||
// Find an empty slot.
|
||||
if s.empty > cap(s.handles)-len(s.handles) {
|
||||
for id, h := range s.handles {
|
||||
if h == nil {
|
||||
s.empty--
|
||||
s.handles[id] = a
|
||||
return ^uint32(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new slot.
|
||||
s.handles = append(s.handles, a)
|
||||
return -uint32(len(s.handles))
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -25,70 +24,67 @@ var (
|
||||
Path string // Path to load the binary from.
|
||||
)
|
||||
|
||||
var sqlite3 struct {
|
||||
var instance struct {
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func instantiateModule() (*module, error) {
|
||||
func compileSQLite() {
|
||||
ctx := context.Background()
|
||||
instance.runtime = wazero.NewRuntime(ctx)
|
||||
|
||||
sqlite3.once.Do(compileModule)
|
||||
if sqlite3.err != nil {
|
||||
return nil, sqlite3.err
|
||||
}
|
||||
|
||||
cfg := wazero.NewModuleConfig()
|
||||
|
||||
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newModule(mod)
|
||||
}
|
||||
|
||||
func compileModule() {
|
||||
ctx := context.Background()
|
||||
sqlite3.runtime = wazero.NewRuntime(ctx)
|
||||
|
||||
env := vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
|
||||
_, sqlite3.err = env.Instantiate(ctx)
|
||||
if sqlite3.err != nil {
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportHostFunctions(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
bin, sqlite3.err = os.ReadFile(Path)
|
||||
if sqlite3.err != nil {
|
||||
bin, instance.err = os.ReadFile(Path)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if bin == nil {
|
||||
sqlite3.err = util.BinaryErr
|
||||
instance.err = util.BinaryErr
|
||||
return
|
||||
}
|
||||
|
||||
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
|
||||
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
type module struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
vfs io.Closer
|
||||
api sqliteAPI
|
||||
arg [8]uint64
|
||||
type sqlite struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
api sqliteAPI
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
func newModule(mod api.Module) (m *module, err error) {
|
||||
m = new(module)
|
||||
m.mod = mod
|
||||
m.ctx, m.vfs = vfs.NewContext(context.Background())
|
||||
type sqliteKey struct{}
|
||||
|
||||
func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
return nil, instance.err
|
||||
}
|
||||
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx = util.NewContext(context.Background())
|
||||
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
|
||||
|
||||
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
|
||||
instance.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getFun := func(name string) api.Function {
|
||||
f := mod.ExportedFunction(name)
|
||||
f := sqlt.mod.ExportedFunction(name)
|
||||
if f == nil {
|
||||
err = util.NoFuncErr + util.ErrorString(name)
|
||||
return nil
|
||||
@@ -97,15 +93,15 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
}
|
||||
|
||||
getVal := func(name string) uint32 {
|
||||
g := mod.ExportedGlobal(name)
|
||||
g := sqlt.mod.ExportedGlobal(name)
|
||||
if g == nil {
|
||||
err = util.NoGlobalErr + util.ErrorString(name)
|
||||
return 0
|
||||
}
|
||||
return util.ReadUint32(mod, uint32(g.Get()))
|
||||
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
|
||||
}
|
||||
|
||||
m.api = sqliteAPI{
|
||||
sqlt.api = sqliteAPI{
|
||||
free: getFun("free"),
|
||||
malloc: getFun("malloc"),
|
||||
destructor: getVal("malloc_destructor"),
|
||||
@@ -153,20 +149,43 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
anyCollation: getFun("sqlite3_anycollseq_init"),
|
||||
createCollation: getFun("sqlite3_create_collation_go"),
|
||||
createFunction: getFun("sqlite3_create_function_go"),
|
||||
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
|
||||
createWindow: getFun("sqlite3_create_window_function_go"),
|
||||
aggregateCtx: getFun("sqlite3_aggregate_context"),
|
||||
userData: getFun("sqlite3_user_data"),
|
||||
setAuxData: getFun("sqlite3_set_auxdata_go"),
|
||||
getAuxData: getFun("sqlite3_get_auxdata"),
|
||||
valueType: getFun("sqlite3_value_type"),
|
||||
valueInteger: getFun("sqlite3_value_int64"),
|
||||
valueFloat: getFun("sqlite3_value_double"),
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
||||
resultErrorBig: getFun("sqlite3_result_error_toobig"),
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
return sqlt, nil
|
||||
}
|
||||
|
||||
func (m *module) close() error {
|
||||
err := m.mod.Close(m.ctx)
|
||||
m.vfs.Close()
|
||||
return err
|
||||
func (sqlt *sqlite) close() error {
|
||||
return sqlt.mod.Close(sqlt.ctx)
|
||||
}
|
||||
|
||||
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
}
|
||||
@@ -177,16 +196,16 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
|
||||
if r := m.call(m.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
|
||||
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
}
|
||||
}
|
||||
@@ -198,60 +217,58 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
return &err
|
||||
}
|
||||
|
||||
func (m *module) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(m.arg[:], params)
|
||||
err := fn.CallWithStack(m.ctx, m.arg[:])
|
||||
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(sqlt.stack[:], params)
|
||||
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
|
||||
if err != nil {
|
||||
// The module closed or panicked; release resources.
|
||||
m.vfs.Close()
|
||||
panic(err)
|
||||
}
|
||||
return m.arg[0]
|
||||
return sqlt.stack[0]
|
||||
}
|
||||
|
||||
func (m *module) free(ptr uint32) {
|
||||
func (sqlt *sqlite) free(ptr uint32) {
|
||||
if ptr == 0 {
|
||||
return
|
||||
}
|
||||
m.call(m.api.free, uint64(ptr))
|
||||
sqlt.call(sqlt.api.free, uint64(ptr))
|
||||
}
|
||||
|
||||
func (m *module) new(size uint64) uint32 {
|
||||
func (sqlt *sqlite) new(size uint64) uint32 {
|
||||
if size > _MAX_ALLOCATION_SIZE {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
ptr := uint32(m.call(m.api.malloc, size))
|
||||
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newBytes(b []byte) uint32 {
|
||||
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := m.new(uint64(len(b)))
|
||||
util.WriteBytes(m.mod, ptr, b)
|
||||
ptr := sqlt.new(uint64(len(b)))
|
||||
util.WriteBytes(sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newString(s string) uint32 {
|
||||
ptr := m.new(uint64(len(s) + 1))
|
||||
util.WriteString(m.mod, ptr, s)
|
||||
func (sqlt *sqlite) newString(s string) uint32 {
|
||||
ptr := sqlt.new(uint64(len(s) + 1))
|
||||
util.WriteString(sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newArena(size uint64) arena {
|
||||
func (sqlt *sqlite) newArena(size uint64) arena {
|
||||
return arena{
|
||||
m: m,
|
||||
base: m.new(size),
|
||||
sqlt: sqlt,
|
||||
size: uint32(size),
|
||||
base: sqlt.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
m *module
|
||||
sqlt *sqlite
|
||||
ptrs []uint32
|
||||
base uint32
|
||||
next uint32
|
||||
@@ -259,17 +276,17 @@ type arena struct {
|
||||
}
|
||||
|
||||
func (a *arena) free() {
|
||||
if a.m == nil {
|
||||
if a.sqlt == nil {
|
||||
return
|
||||
}
|
||||
a.reset()
|
||||
a.m.free(a.base)
|
||||
a.m = nil
|
||||
a.sqlt.free(a.base)
|
||||
a.sqlt = nil
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.m.free(ptr)
|
||||
a.sqlt.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
@@ -281,7 +298,7 @@ func (a *arena) new(size uint64) uint32 {
|
||||
a.next += uint32(size)
|
||||
return ptr
|
||||
}
|
||||
ptr := a.m.new(size)
|
||||
ptr := a.sqlt.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
@@ -291,13 +308,13 @@ func (a *arena) bytes(b []byte) uint32 {
|
||||
return 0
|
||||
}
|
||||
ptr := a.new(uint64(len(b)))
|
||||
util.WriteBytes(a.m.mod, ptr, b)
|
||||
util.WriteBytes(a.sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint64(len(s) + 1))
|
||||
util.WriteString(a.m.mod, ptr, s)
|
||||
util.WriteString(a.sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
@@ -317,10 +334,10 @@ type sqliteAPI struct {
|
||||
step api.Function
|
||||
exec api.Function
|
||||
clearBindings api.Function
|
||||
bindNull api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindNull api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
@@ -348,5 +365,30 @@ type sqliteAPI struct {
|
||||
changes api.Function
|
||||
lastRowid api.Function
|
||||
autocommit api.Function
|
||||
anyCollation api.Function
|
||||
createCollation api.Function
|
||||
createFunction api.Function
|
||||
createAggregate api.Function
|
||||
createWindow api.Function
|
||||
aggregateCtx api.Function
|
||||
userData api.Function
|
||||
setAuxData api.Function
|
||||
getAuxData api.Function
|
||||
valueType api.Function
|
||||
valueInteger api.Function
|
||||
valueFloat api.Function
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
resultErrorMem api.Function
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
}
|
||||
1
sqlite3/.gitignore
vendored
1
sqlite3/.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
ext/
|
||||
sqlite3.c
|
||||
sqlite3.h
|
||||
sqlite3ext.h
|
||||
@@ -1,20 +0,0 @@
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -60425,7 +60425,7 @@
|
||||
int rc = SQLITE_OK; /* Return code */
|
||||
int tempFile = 0; /* True for temp files (incl. in-memory files) */
|
||||
int memDb = 0; /* True if this is an in-memory file */
|
||||
-#ifndef SQLITE_OMIT_DESERIALIZE
|
||||
+#if 1
|
||||
int memJM = 0; /* Memory journal mode */
|
||||
#else
|
||||
# define memJM 0
|
||||
@@ -60628,7 +60628,7 @@
|
||||
int fout = 0; /* VFS flags returned by xOpen() */
|
||||
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
|
||||
assert( !memDb );
|
||||
-#ifndef SQLITE_OMIT_DESERIALIZE
|
||||
+#if 1
|
||||
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
|
||||
#endif
|
||||
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;
|
||||
@@ -3,32 +3,33 @@ set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3430200.zip"
|
||||
unzip -d . sqlite-amalgamation-*.zip
|
||||
mv sqlite-amalgamation-*/sqlite3* .
|
||||
rm -rf sqlite-amalgamation-*
|
||||
|
||||
patch < vfs_find.patch
|
||||
patch < deserialize.patch
|
||||
cat *.patch | patch --posix
|
||||
|
||||
mkdir -p ext/
|
||||
cd ext/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/ext/misc/anycollseq.c"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/mptest/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/mptest/multiwrite01.test"
|
||||
cd ~-
|
||||
|
||||
cd ../vfs/tests/speedtest1/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.2/test/speedtest1.c"
|
||||
cd ~-
|
||||
1
sqlite3/ext/.gitignore
vendored
1
sqlite3/ext/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
*.c
|
||||
40
sqlite3/func.c
Normal file
40
sqlite3/func.c
Normal file
@@ -0,0 +1,40 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_compare(void *, int, const void *, int, const void *);
|
||||
void go_func(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_step(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_final(sqlite3_context *);
|
||||
void go_value(sqlite3_context *);
|
||||
void go_inverse(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_destroy(void *);
|
||||
|
||||
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
|
||||
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
|
||||
go_func, NULL, NULL, go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
|
||||
int nArg, int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, NULL, NULL,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, go_value,
|
||||
go_inverse, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
|
||||
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
|
||||
}
|
||||
14
sqlite3/locking_mode.patch
Normal file
14
sqlite3/locking_mode.patch
Normal file
@@ -0,0 +1,14 @@
|
||||
# Use exclusive locking mode for WAL databases with v1 VFSes.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -63210,7 +63210,9 @@
|
||||
SQLITE_PRIVATE int sqlite3PagerWalSupported(Pager *pPager){
|
||||
const sqlite3_io_methods *pMethods = pPager->fd->pMethods;
|
||||
if( pPager->noLock ) return 0;
|
||||
- return pPager->exclusiveMode || (pMethods->iVersion>=2 && pMethods->xShmMap);
|
||||
+ if( pMethods->iVersion>=2 && pMethods->xShmMap ) return 1;
|
||||
+ pPager->exclusiveMode = 1;
|
||||
+ return 1;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -1,19 +1,19 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
// Configuration
|
||||
#include "sqlite_cfg.h"
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
// VFS
|
||||
#include "vfs.c"
|
||||
// Extensions
|
||||
#include "ext/anycollseq.c"
|
||||
#include "ext/base64.c"
|
||||
#include "ext/decimal.c"
|
||||
#include "ext/regexp.c"
|
||||
#include "ext/series.c"
|
||||
#include "ext/uint.c"
|
||||
#include "ext/uuid.c"
|
||||
#include "func.c"
|
||||
#include "time.c"
|
||||
|
||||
__attribute__((constructor)) void init() {
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#define SQLITE_USE_ALLOCA
|
||||
|
||||
// Other Options
|
||||
|
||||
#define SQLITE_ALLOW_URI_AUTHORITY
|
||||
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
|
||||
#define SQLITE_ENABLE_ATOMIC_WRITE
|
||||
@@ -36,12 +37,9 @@
|
||||
|
||||
// Because WASM does not support shared memory,
|
||||
// SQLite disables WAL for WASM builds.
|
||||
// We set the default locking mode to EXCLUSIVE instead.
|
||||
// We patch SQLite to use exclusive locking mode instead.
|
||||
// https://www.sqlite.org/wal.html#noshm
|
||||
#undef SQLITE_OMIT_WAL
|
||||
#ifndef SQLITE_DEFAULT_LOCKING_MODE
|
||||
#define SQLITE_DEFAULT_LOCKING_MODE 1
|
||||
#endif
|
||||
|
||||
// Amalgamated Extensions
|
||||
|
||||
@@ -58,5 +56,7 @@
|
||||
// #define SQLITE_ENABLE_SESSION
|
||||
// #define SQLITE_ENABLE_PREUPDATE_HOOK
|
||||
|
||||
#define SQLITE_SOUNDEX
|
||||
|
||||
// Implemented in vfs.c.
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime);
|
||||
@@ -134,4 +134,4 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 280, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
|
||||
@@ -1,3 +1,4 @@
|
||||
# Wrap sqlite3_vfs_find.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -25394,7 +25394,7 @@
|
||||
|
||||
@@ -12,67 +12,67 @@ func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
|
||||
func TestConn_error_OOM(t *testing.T) {
|
||||
func Test_sqlite_error_OOM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
m.error(uint64(NOMEM), 0)
|
||||
sqlite.error(uint64(NOMEM), 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_call_closed(t *testing.T) {
|
||||
func Test_sqlite_call_closed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.close()
|
||||
sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
m.call(m.api.free)
|
||||
sqlite.call(sqlite.api.free)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_new(t *testing.T) {
|
||||
func Test_sqlite_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
t.Run("MaxUint32", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
m.new(math.MaxUint32)
|
||||
sqlite.new(math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
})
|
||||
|
||||
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
m.new(_MAX_ALLOCATION_SIZE)
|
||||
m.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
t.Error("want panic")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConn_newArena(t *testing.T) {
|
||||
func Test_sqlite_newArena(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
arena := m.newArena(16)
|
||||
arena := sqlite.newArena(16)
|
||||
defer arena.free()
|
||||
|
||||
const title = "Lorem ipsum"
|
||||
@@ -80,7 +80,7 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
|
||||
t.Errorf("got %q, want %q", got, body)
|
||||
}
|
||||
|
||||
@@ -101,121 +101,121 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
arena.free()
|
||||
}
|
||||
|
||||
func TestConn_newBytes(t *testing.T) {
|
||||
func Test_sqlite_newBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newBytes(nil)
|
||||
ptr := sqlite.newBytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("got %#x, want nullptr", ptr)
|
||||
}
|
||||
|
||||
buf := []byte("sqlite3")
|
||||
ptr = m.newBytes(buf)
|
||||
ptr = sqlite.newBytes(buf)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := buf
|
||||
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_newString(t *testing.T) {
|
||||
func Test_sqlite_newString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newString("")
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3\000sqlite3"
|
||||
ptr = m.newString(str)
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := str + "\000"
|
||||
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_getString(t *testing.T) {
|
||||
func Test_sqlite_getString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newString("")
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3" + "\000 drop this"
|
||||
ptr = m.newString(str)
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := "sqlite3"
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, 0); got != "" {
|
||||
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(m.mod, ptr, uint32(len(want)/2))
|
||||
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
|
||||
t.Error("want panic")
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(m.mod, 0, math.MaxUint32)
|
||||
util.ReadString(sqlite.mod, 0, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}()
|
||||
}
|
||||
|
||||
func TestConn_free(t *testing.T) {
|
||||
func Test_sqlite_free(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
m.free(0)
|
||||
sqlite.free(0)
|
||||
|
||||
ptr := m.new(1)
|
||||
ptr := sqlite.new(1)
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
m.free(ptr)
|
||||
sqlite.free(ptr)
|
||||
}
|
||||
35
stmt.go
35
stmt.go
@@ -61,12 +61,12 @@ func (s *Stmt) ClearBindings() error {
|
||||
func (s *Stmt) Step() bool {
|
||||
s.c.checkInterrupt()
|
||||
r := s.c.call(s.c.api.step, uint64(s.handle))
|
||||
if r == _ROW {
|
||||
switch r {
|
||||
case _ROW:
|
||||
return true
|
||||
}
|
||||
if r == _DONE {
|
||||
case _DONE:
|
||||
s.err = nil
|
||||
} else {
|
||||
default:
|
||||
s.err = s.c.error(r)
|
||||
}
|
||||
return false
|
||||
@@ -131,10 +131,11 @@ func (s *Stmt) BindName(param int) string {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindBool(param int, value bool) error {
|
||||
var i int64
|
||||
if value {
|
||||
return s.BindInt64(param, 1)
|
||||
i = 1
|
||||
}
|
||||
return s.BindInt64(param, 0)
|
||||
return s.BindInt64(param, i)
|
||||
}
|
||||
|
||||
// BindInt binds an int to the prepared statement.
|
||||
@@ -374,18 +375,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
|
||||
func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnText,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
ptr := uint32(r)
|
||||
if ptr == 0 {
|
||||
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r = s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
// ColumnRawBlob returns the value of the result column as a []byte.
|
||||
@@ -397,17 +387,18 @@ func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
func (s *Stmt) ColumnRawBlob(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnBlob,
|
||||
uint64(s.handle), uint64(col))
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
ptr := uint32(r)
|
||||
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r = s.c.call(s.c.api.columnBytes,
|
||||
r := s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.Join(t.TempDir(), "foo.db")+
|
||||
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
|
||||
"?_pragma=busy_timeout(10000)&_pragma=synchronous(off)")
|
||||
if err != nil {
|
||||
t.Fatalf("foo.db open fail: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/wal.db
|
||||
var waldb []byte
|
||||
|
||||
func TestDB_memory(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, ":memory:")
|
||||
@@ -19,6 +25,16 @@ func TestDB_file(t *testing.T) {
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func TestDB_wal(t *testing.T) {
|
||||
t.Parallel()
|
||||
wal := filepath.Join(t.TempDir(), "test.db")
|
||||
err := os.WriteFile(wal, waldb, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDB(t, wal)
|
||||
}
|
||||
|
||||
func TestDB_vfs(t *testing.T) {
|
||||
testDB(t, "file:test.db?vfs=memdb")
|
||||
}
|
||||
|
||||
@@ -2,10 +2,9 @@ package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
188
tests/func_test.go
Normal file
188
tests/func_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestCreateFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg := arg[0]; arg.Int() {
|
||||
case 0:
|
||||
ctx.ResultInt(arg.Int())
|
||||
case 1:
|
||||
ctx.ResultInt64(arg.Int64())
|
||||
case 2:
|
||||
ctx.ResultBool(arg.Bool())
|
||||
case 3:
|
||||
ctx.ResultFloat(arg.Float())
|
||||
case 4:
|
||||
ctx.ResultText(arg.Text())
|
||||
case 5:
|
||||
ctx.ResultBlob(arg.Blob(nil))
|
||||
case 6:
|
||||
ctx.ResultZeroBlob(arg.Int64())
|
||||
case 7:
|
||||
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
|
||||
case 8:
|
||||
ctx.ResultNull()
|
||||
case 9:
|
||||
ctx.ResultError(sqlite3.FULL)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 1 {
|
||||
t.Errorf("got %v, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 3 {
|
||||
t.Errorf("got %v, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "4" {
|
||||
t.Errorf("got %s, want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
|
||||
t.Errorf("got %s, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
|
||||
t.Errorf("got %v, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
t.Error("want error")
|
||||
}
|
||||
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
|
||||
t.Errorf("got %v, want sqlite3.FULL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnyCollationNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.AnyCollationNeeded()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 2, 1}
|
||||
names := []string{"go", "whatever", "zig"}
|
||||
for ; stmt.Step(); row++ {
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,6 @@ func TestParallel(t *testing.T) {
|
||||
name := "file:" +
|
||||
filepath.Join(t.TempDir(), "test.db") +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
@@ -42,7 +41,6 @@ func TestMemory(t *testing.T) {
|
||||
|
||||
name := "file:/test.db?vfs=memdb" +
|
||||
"&_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(memory)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
@@ -59,7 +57,6 @@ func TestMultiProcess(t *testing.T) {
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
@@ -93,7 +90,6 @@ func TestChildProcess(t *testing.T) {
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
@@ -128,10 +124,7 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA busy_timeout=10000;
|
||||
PRAGMA locking_mode=normal;
|
||||
`)
|
||||
err = db.Exec(`PRAGMA busy_timeout=10000`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
BIN
tests/testdata/wal.db
vendored
Normal file
BIN
tests/testdata/wal.db
vendored
Normal file
Binary file not shown.
125
value.go
Normal file
125
value.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Value is any value that can be stored in a database table.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value.html
|
||||
type Value struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Type returns the initial [Datatype] of the value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Type() Datatype {
|
||||
r := v.call(v.api.valueType, uint64(v.handle))
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
// Bool returns the value as a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are retrieved as integers,
|
||||
// with 0 converted to false and any other value to true.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Bool() bool {
|
||||
if i := v.Int64(); i != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Int returns the value as an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int() int {
|
||||
return int(v.Int64())
|
||||
}
|
||||
|
||||
// Int64 returns the value as an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int64() int64 {
|
||||
r := v.call(v.api.valueInteger, uint64(v.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// Float returns the value as a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Float() float64 {
|
||||
r := v.call(v.api.valueFloat, uint64(v.handle))
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// Time returns the value as a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Time(format TimeFormat) time.Time {
|
||||
var a any
|
||||
switch v.Type() {
|
||||
case INTEGER:
|
||||
a = v.Int64()
|
||||
case FLOAT:
|
||||
a = v.Float()
|
||||
case TEXT, BLOB:
|
||||
a = v.Text()
|
||||
case NULL:
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
t, _ := format.Decode(a)
|
||||
return t
|
||||
}
|
||||
|
||||
// Text returns the value as a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Text() string {
|
||||
return string(v.RawText())
|
||||
}
|
||||
|
||||
// Blob appends to buf and returns
|
||||
// the value as a []byte.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Blob(buf []byte) []byte {
|
||||
return append(buf, v.RawBlob()...)
|
||||
}
|
||||
|
||||
// RawText returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawText() []byte {
|
||||
r := v.call(v.api.valueText, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
// RawBlob returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawBlob() []byte {
|
||||
r := v.call(v.api.valueBlob, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
func (v Value) rawBytes(ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := v.call(v.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
@@ -2,8 +2,30 @@
|
||||
|
||||
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
|
||||
|
||||
It replaces the default VFS with a pure Go implementation,
|
||||
that is tested on Linux, macOS and Windows,
|
||||
but which should also work on illumos and the various BSDs.
|
||||
It replaces the default SQLite VFS with a pure Go implementation.
|
||||
|
||||
It also exposes interfaces that should allow you to implement your own custom VFSes.
|
||||
It also exposes interfaces that should allow you to implement your own custom VFSes.
|
||||
|
||||
## Portability
|
||||
|
||||
This package is tested on Linux, macOS and Windows,
|
||||
but it should also work on FreeBSD and illumos
|
||||
(code paths for those plaforms are tested on macOS and Linux, respectively).
|
||||
|
||||
In all platforms for which this package builds,
|
||||
it should be safe to use it to access databases concurrently,
|
||||
from multiple goroutines, processes, and
|
||||
with _other_ implementations of SQLite.
|
||||
|
||||
If the package does not build for your platform,
|
||||
you may try to use the `sqlite3_flock` and `sqlite3_nolock` build tags.
|
||||
These are only minimally tested and concurrency test failures should be expected.
|
||||
|
||||
The `sqlite3_flock` tag uses
|
||||
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
|
||||
It should be safe to access databases concurrently from multiple goroutines and processes,
|
||||
but **not** with _other_ implementations of SQLite
|
||||
(_unless_ these are _also_ configured to use `flock`).
|
||||
|
||||
The `sqlite3_nolock` tag uses no locking at all.
|
||||
Database corruption is the likely result from concurrent write access.
|
||||
23
vfs/api.go
23
vfs/api.go
@@ -15,7 +15,7 @@ type VFS interface {
|
||||
FullPathname(name string) (string, error)
|
||||
}
|
||||
|
||||
// VFSParams extends VFS to with the ability to handle URI parameters
|
||||
// VFSParams extends VFS with the ability to handle URI parameters
|
||||
// through the OpenParams method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/uri_boolean.html
|
||||
@@ -47,7 +47,7 @@ type File interface {
|
||||
// FileLockState extends File to implement the
|
||||
// SQLITE_FCNTL_LOCKSTATE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntllockstate
|
||||
type FileLockState interface {
|
||||
File
|
||||
LockState() LockLevel
|
||||
@@ -56,7 +56,7 @@ type FileLockState interface {
|
||||
// FileSizeHint extends File to implement the
|
||||
// SQLITE_FCNTL_SIZE_HINT file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlsizehint
|
||||
type FileSizeHint interface {
|
||||
File
|
||||
SizeHint(size int64) error
|
||||
@@ -65,16 +65,25 @@ type FileSizeHint interface {
|
||||
// FileHasMoved extends File to implement the
|
||||
// SQLITE_FCNTL_HAS_MOVED file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlhasmoved
|
||||
type FileHasMoved interface {
|
||||
File
|
||||
HasMoved() (bool, error)
|
||||
}
|
||||
|
||||
// FileOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_OVERWRITE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntloverwrite
|
||||
type FileOverwrite interface {
|
||||
File
|
||||
Overwrite() error
|
||||
}
|
||||
|
||||
// FilePowersafeOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlpowersafeoverwrite
|
||||
type FilePowersafeOverwrite interface {
|
||||
File
|
||||
PowersafeOverwrite() bool
|
||||
@@ -84,7 +93,7 @@ type FilePowersafeOverwrite interface {
|
||||
// FilePowersafeOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlcommitphasetwo
|
||||
type FileCommitPhaseTwo interface {
|
||||
File
|
||||
CommitPhaseTwo() error
|
||||
@@ -94,7 +103,7 @@ type FileCommitPhaseTwo interface {
|
||||
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
|
||||
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlbeginatomicwrite
|
||||
type FileBatchAtomicWrite interface {
|
||||
File
|
||||
BeginAtomicWrite() error
|
||||
|
||||
9
vfs/clear.go
Normal file
9
vfs/clear.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !go1.21
|
||||
|
||||
package vfs
|
||||
|
||||
func clear(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
10
vfs/file.go
10
vfs/file.go
@@ -9,7 +9,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type vfsOS struct{}
|
||||
@@ -124,11 +123,10 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
|
||||
|
||||
type vfsFile struct {
|
||||
*os.File
|
||||
lockTimeout time.Duration
|
||||
lock LockLevel
|
||||
psow bool
|
||||
syncDir bool
|
||||
readOnly bool
|
||||
lock LockLevel
|
||||
psow bool
|
||||
syncDir bool
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
13
vfs/lock.go
13
vfs/lock.go
@@ -1,8 +1,9 @@
|
||||
//go:build !sqlite3_nolock
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
@@ -48,7 +49,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
if f.lock != LOCK_NONE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
|
||||
if rc := osGetSharedLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_SHARED
|
||||
@@ -59,7 +60,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
if f.lock != LOCK_SHARED {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
|
||||
if rc := osGetReservedLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_RESERVED
|
||||
@@ -77,7 +78,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
}
|
||||
f.lock = LOCK_PENDING
|
||||
}
|
||||
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
|
||||
if rc := osGetExclusiveLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_EXCLUSIVE
|
||||
@@ -134,9 +135,9 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
|
||||
return osCheckReservedLock(f.File)
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
func osGetReservedLock(file *os.File) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
@@ -12,13 +11,6 @@ import (
|
||||
)
|
||||
|
||||
func Test_vfsLock(t *testing.T) {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin", "windows":
|
||||
break
|
||||
default:
|
||||
t.Skip("OS lacks OFD locks")
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
|
||||
// Create a temporary file.
|
||||
@@ -41,8 +33,7 @@ func Test_vfsLock(t *testing.T) {
|
||||
pOutput = 32
|
||||
)
|
||||
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
|
||||
ctx, vfs := NewContext(context.TODO())
|
||||
defer vfs.Close()
|
||||
ctx := util.NewContext(context.TODO())
|
||||
|
||||
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
|
||||
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
|
||||
@@ -212,9 +203,4 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
}
|
||||
|
||||
10
vfs/memdb/clear.go
Normal file
10
vfs/memdb/clear.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !go1.21
|
||||
|
||||
package memdb
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
n = copy((*m.data[base])[rest:], b)
|
||||
if n < len(b) {
|
||||
// Assume writes are page aligned.
|
||||
return 0, io.ErrShortWrite
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
if size := off + int64(len(b)); size > m.size {
|
||||
m.size = size
|
||||
@@ -176,6 +176,8 @@ func (m *memFile) Size() (int64, error) {
|
||||
return m.size, nil
|
||||
}
|
||||
|
||||
const spinWait = 25 * time.Microsecond
|
||||
|
||||
func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
if m.lock >= lock {
|
||||
return nil
|
||||
@@ -187,17 +189,11 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
deadline := time.Now().Add(time.Millisecond)
|
||||
|
||||
switch lock {
|
||||
case vfs.LOCK_SHARED:
|
||||
for m.pending != nil {
|
||||
if time.Now().After(deadline) {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
if m.pending != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.shared++
|
||||
|
||||
@@ -216,8 +212,8 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
m.pending = m
|
||||
}
|
||||
|
||||
for m.shared > 1 {
|
||||
if time.Now().After(deadline) {
|
||||
for before := time.Now(); m.shared > 1; {
|
||||
if time.Since(before) > spinWait {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
@@ -291,10 +287,3 @@ func divRoundUp(a, b int64) int64 {
|
||||
func modRoundUp(a, b int64) int64 {
|
||||
return b - (b-a%b)%b
|
||||
}
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
|
||||
22
vfs/nolock.go
Normal file
22
vfs/nolock.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build sqlite3_nolock
|
||||
|
||||
package vfs
|
||||
|
||||
const (
|
||||
_PENDING_BYTE = 0x40000000
|
||||
_RESERVED_BYTE = (_PENDING_BYTE + 1)
|
||||
_SHARED_FIRST = (_PENDING_BYTE + 2)
|
||||
_SHARED_SIZE = 510
|
||||
)
|
||||
|
||||
func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *vfsFile) Unlock(lock LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *vfsFile) CheckReservedLock() (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
|
||||
//go:build sqlite3_flock || freebsd
|
||||
|
||||
package vfs
|
||||
|
||||
@@ -20,16 +20,16 @@ func osUnlock(file *os.File, start, len int64) _ErrorCode {
|
||||
}
|
||||
|
||||
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
|
||||
before := time.Now()
|
||||
var err error
|
||||
for {
|
||||
err = unix.Flock(int(file.Fd()), how)
|
||||
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
|
||||
break
|
||||
}
|
||||
if timeout < time.Millisecond {
|
||||
if timeout <= 0 || timeout < time.Since(before) {
|
||||
break
|
||||
}
|
||||
timeout -= time.Millisecond
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
return osLockErrorCode(err, def)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !sqlite3_bsd
|
||||
//go:build !sqlite3_flock
|
||||
|
||||
package vfs
|
||||
|
||||
|
||||
28
vfs/os_nolock.go
Normal file
28
vfs/os_nolock.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build sqlite3_nolock && unix && !(linux || darwin || freebsd || illumos)
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func osUnlock(file *os.File, start, len int64) _ErrorCode {
|
||||
return _OK
|
||||
}
|
||||
|
||||
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
|
||||
return _OK
|
||||
}
|
||||
|
||||
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
|
||||
return _OK
|
||||
}
|
||||
|
||||
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
|
||||
return _OK
|
||||
}
|
||||
|
||||
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
|
||||
return false, _OK
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || illumos
|
||||
//go:build (linux || illumos) && !sqlite3_flock
|
||||
|
||||
package vfs
|
||||
|
||||
@@ -27,16 +27,16 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
|
||||
Start: start,
|
||||
Len: len,
|
||||
}
|
||||
before := time.Now()
|
||||
var err error
|
||||
for {
|
||||
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
|
||||
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
|
||||
break
|
||||
}
|
||||
if timeout < time.Millisecond {
|
||||
if timeout <= 0 || timeout < time.Since(before) {
|
||||
break
|
||||
}
|
||||
timeout -= time.Millisecond
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
return osLockErrorCode(err, def)
|
||||
|
||||
36
vfs/os_std_access.go
Normal file
36
vfs/os_std_access.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build !unix
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
_S_IREAD = 0400
|
||||
_S_IWRITE = 0200
|
||||
_S_IEXEC = 0100
|
||||
)
|
||||
|
||||
func osAccess(path string, flags AccessFlag) error {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if flags == ACCESS_EXISTS {
|
||||
return nil
|
||||
}
|
||||
|
||||
var want fs.FileMode = _S_IREAD
|
||||
if flags == ACCESS_READWRITE {
|
||||
want |= _S_IWRITE
|
||||
}
|
||||
if fi.IsDir() {
|
||||
want |= _S_IEXEC
|
||||
}
|
||||
if fi.Mode()&want != want {
|
||||
return fs.ErrPermission
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && (!darwin || sqlite3_bsd)
|
||||
//go:build !linux && (!darwin || sqlite3_flock)
|
||||
|
||||
package vfs
|
||||
|
||||
@@ -7,10 +7,6 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func osSync(file *os.File, fullsync, dataonly bool) error {
|
||||
return file.Sync()
|
||||
}
|
||||
|
||||
func osAllocate(file *os.File, size int64) error {
|
||||
off, err := file.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
14
vfs/os_std_mode.go
Normal file
14
vfs/os_std_mode.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !unix
|
||||
|
||||
package vfs
|
||||
|
||||
import "os"
|
||||
|
||||
func osSetMode(file *os.File, modeof string) error {
|
||||
fi, err := os.Stat(modeof)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.Chmod(fi.Mode())
|
||||
return nil
|
||||
}
|
||||
12
vfs/os_std_open.go
Normal file
12
vfs/os_std_open.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
)
|
||||
|
||||
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
9
vfs/os_std_sync.go
Normal file
9
vfs/os_std_sync.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux && (!darwin || sqlite3_flock)
|
||||
|
||||
package vfs
|
||||
|
||||
import "os"
|
||||
|
||||
func osSync(file *os.File, fullsync, dataonly bool) error {
|
||||
return file.Sync()
|
||||
}
|
||||
@@ -3,7 +3,6 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -11,10 +10,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
|
||||
func osAccess(path string, flags AccessFlag) error {
|
||||
var access uint32 // unix.F_OK
|
||||
switch flags {
|
||||
@@ -38,22 +33,18 @@ func osSetMode(file *os.File, modeof string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
func osGetSharedLock(file *os.File) _ErrorCode {
|
||||
// Test the PENDING lock before acquiring a new SHARED lock.
|
||||
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
|
||||
return _BUSY
|
||||
}
|
||||
// Acquire the SHARED lock.
|
||||
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
|
||||
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
if timeout == 0 {
|
||||
timeout = time.Millisecond
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File) _ErrorCode {
|
||||
// Acquire the EXCLUSIVE lock.
|
||||
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
|
||||
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
|
||||
}
|
||||
|
||||
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
|
||||
@@ -25,40 +25,9 @@ func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
|
||||
return os.NewFile(uintptr(r), name), nil
|
||||
}
|
||||
|
||||
func osAccess(path string, flags AccessFlag) error {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if flags == ACCESS_EXISTS {
|
||||
return nil
|
||||
}
|
||||
|
||||
var want fs.FileMode = windows.S_IRUSR
|
||||
if flags == ACCESS_READWRITE {
|
||||
want |= windows.S_IWUSR
|
||||
}
|
||||
if fi.IsDir() {
|
||||
want |= windows.S_IXUSR
|
||||
}
|
||||
if fi.Mode()&want != want {
|
||||
return fs.ErrPermission
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func osSetMode(file *os.File, modeof string) error {
|
||||
fi, err := os.Stat(modeof)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.Chmod(fi.Mode())
|
||||
return nil
|
||||
}
|
||||
|
||||
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
func osGetSharedLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
|
||||
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
|
||||
rc := osReadLock(file, _PENDING_BYTE, 1, 0)
|
||||
|
||||
if rc == _OK {
|
||||
// Acquire the SHARED lock.
|
||||
@@ -70,16 +39,12 @@ func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
return rc
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
if timeout == 0 {
|
||||
timeout = time.Millisecond
|
||||
}
|
||||
|
||||
func osGetExclusiveLock(file *os.File) _ErrorCode {
|
||||
// Release the SHARED lock.
|
||||
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
|
||||
|
||||
// Acquire the EXCLUSIVE lock.
|
||||
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
|
||||
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
|
||||
|
||||
if rc != _OK {
|
||||
// Reacquire the SHARED lock.
|
||||
@@ -138,6 +103,7 @@ func osUnlock(file *os.File, start, len uint32) _ErrorCode {
|
||||
}
|
||||
|
||||
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
|
||||
before := time.Now()
|
||||
var err error
|
||||
for {
|
||||
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
|
||||
@@ -145,11 +111,16 @@ func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def
|
||||
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
|
||||
break
|
||||
}
|
||||
if timeout < time.Millisecond {
|
||||
if timeout <= 0 || timeout < time.Since(before) {
|
||||
break
|
||||
}
|
||||
if err := windows.TimeBeginPeriod(1); err != nil {
|
||||
break
|
||||
}
|
||||
timeout -= time.Millisecond
|
||||
time.Sleep(time.Millisecond)
|
||||
if err := windows.TimeEndPeriod(1); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return osLockErrorCode(err, def)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ var (
|
||||
)
|
||||
|
||||
// Create creates an immutable database from reader.
|
||||
// The caller should insure that data from reader does not mutate,
|
||||
// The caller should ensure that data from reader does not mutate,
|
||||
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
|
||||
func Create(name string, reader SizeReaderAt) {
|
||||
readerMtx.Lock()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package readervfs_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
_ "embed"
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
//go:embed testdata/test.db
|
||||
var testDB []byte
|
||||
var testDB string
|
||||
|
||||
func Example_http() {
|
||||
readervfs.Create("demo.db", httpreadat.New("https://www.sanford.io/demo.db"))
|
||||
@@ -65,7 +65,7 @@ func Example_http() {
|
||||
}
|
||||
|
||||
func Example_embed() {
|
||||
readervfs.Create("test.db", readervfs.NewSizeReaderAt(bytes.NewReader(testDB)))
|
||||
readervfs.Create("test.db", readervfs.NewSizeReaderAt(strings.NewReader(testDB)))
|
||||
defer readervfs.Delete("test.db")
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:test.db?vfs=reader")
|
||||
|
||||
@@ -37,7 +37,7 @@ func (readerVFS) FullPathname(name string) (string, error) {
|
||||
|
||||
type readerFile struct{ SizeReaderAt }
|
||||
|
||||
func (r readerFile) Close() error {
|
||||
func (readerFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package mptest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/bzip2"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"embed"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
"github.com/tetratelabs/wazero"
|
||||
@@ -23,8 +25,8 @@ import (
|
||||
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
|
||||
)
|
||||
|
||||
//go:embed testdata/mptest.wasm
|
||||
var binary []byte
|
||||
//go:embed testdata/mptest.wasm.bz2
|
||||
var compressed string
|
||||
|
||||
//go:embed testdata/*.*test
|
||||
var scripts embed.FS
|
||||
@@ -36,8 +38,7 @@ var (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
ctx := context.TODO()
|
||||
|
||||
ctx := context.Background()
|
||||
rt = wazero.NewRuntime(ctx)
|
||||
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
|
||||
|
||||
@@ -48,6 +49,11 @@ func TestMain(m *testing.M) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
module, err = rt.CompileModule(ctx, binary)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -83,16 +89,15 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
|
||||
|
||||
cfg := config(ctx).WithArgs(args...)
|
||||
go func() {
|
||||
ctx, vfs := vfs.NewContext(ctx)
|
||||
ctx := util.NewContext(ctx)
|
||||
mod, _ := rt.InstantiateModule(ctx, module, cfg)
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}()
|
||||
return 0
|
||||
}
|
||||
|
||||
func Test_config01(t *testing.T) {
|
||||
ctx, vfs := vfs.NewContext(newContext(t))
|
||||
ctx := util.NewContext(newContext(t))
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
|
||||
mod, err := rt.InstantiateModule(ctx, module, cfg)
|
||||
@@ -100,7 +105,6 @@ func Test_config01(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
func Test_config02(t *testing.T) {
|
||||
@@ -111,7 +115,7 @@ func Test_config02(t *testing.T) {
|
||||
t.Skip("skipping in CI")
|
||||
}
|
||||
|
||||
ctx, vfs := vfs.NewContext(newContext(t))
|
||||
ctx := util.NewContext(newContext(t))
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
|
||||
mod, err := rt.InstantiateModule(ctx, module, cfg)
|
||||
@@ -119,7 +123,6 @@ func Test_config02(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
func Test_crash01(t *testing.T) {
|
||||
@@ -127,7 +130,7 @@ func Test_crash01(t *testing.T) {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
ctx, vfs := vfs.NewContext(newContext(t))
|
||||
ctx := util.NewContext(newContext(t))
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
|
||||
mod, err := rt.InstantiateModule(ctx, module, cfg)
|
||||
@@ -135,7 +138,6 @@ func Test_crash01(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
func Test_multiwrite01(t *testing.T) {
|
||||
@@ -143,7 +145,7 @@ func Test_multiwrite01(t *testing.T) {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
ctx, vfs := vfs.NewContext(newContext(t))
|
||||
ctx := util.NewContext(newContext(t))
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
|
||||
mod, err := rt.InstantiateModule(ctx, module, cfg)
|
||||
@@ -151,12 +153,11 @@ func Test_multiwrite01(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
func Test_config01_memory(t *testing.T) {
|
||||
ctx, vfs := vfs.NewContext(newContext(t))
|
||||
cfg := config(ctx).WithArgs("mptest", "test.db",
|
||||
ctx := util.NewContext(newContext(t))
|
||||
cfg := config(ctx).WithArgs("mptest", "/test.db",
|
||||
"config01.test",
|
||||
"--vfs", "memdb",
|
||||
"--timeout", "1000")
|
||||
@@ -165,7 +166,6 @@ func Test_config01_memory(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
func Test_multiwrite01_memory(t *testing.T) {
|
||||
@@ -173,7 +173,7 @@ func Test_multiwrite01_memory(t *testing.T) {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
ctx, vfs := vfs.NewContext(newContext(t))
|
||||
ctx := util.NewContext(newContext(t))
|
||||
cfg := config(ctx).WithArgs("mptest", "/test.db",
|
||||
"multiwrite01.test",
|
||||
"--vfs", "memdb",
|
||||
@@ -183,7 +183,6 @@ func Test_multiwrite01_memory(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
func newContext(t *testing.T) context.Context {
|
||||
|
||||
2
vfs/tests/mptest/testdata/.gitattributes
vendored
2
vfs/tests/mptest/testdata/.gitattributes
vendored
@@ -1,2 +1,2 @@
|
||||
mptest.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
mptest.wasm.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.*test -crlf
|
||||
14
vfs/tests/mptest/testdata/build.sh
vendored
14
vfs/tests/mptest/testdata/build.sh
vendored
@@ -4,25 +4,29 @@ set -euo pipefail
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
ROOT=../../../../
|
||||
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
|
||||
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
|
||||
-o mptest.wasm main.c \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-mmutable-globals \
|
||||
-msimd128 -mmutable-globals \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
|
||||
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
|
||||
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
|
||||
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
|
||||
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid
|
||||
|
||||
"$BINARYEN/wasm-opt" -g -O2 mptest.wasm -o mptest.tmp \
|
||||
--enable-multivalue --enable-mutable-globals \
|
||||
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
|
||||
mptest.wasm -o mptest.tmp \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
mv mptest.tmp mptest.wasm
|
||||
mv mptest.tmp mptest.wasm
|
||||
bzip2 -9f mptest.wasm
|
||||
2
vfs/tests/mptest/testdata/main.c
vendored
2
vfs/tests/mptest/testdata/main.c
vendored
@@ -1,8 +1,6 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
// Configuration
|
||||
#include "sqlite_cfg.h"
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
// VFS
|
||||
|
||||
3
vfs/tests/mptest/testdata/mptest.wasm
vendored
3
vfs/tests/mptest/testdata/mptest.wasm
vendored
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:53bb204be0ddd312c23976e7b627ad62fb384d220557f4e3f723d89273d78a01
|
||||
size 1477009
|
||||
3
vfs/tests/mptest/testdata/mptest.wasm.bz2
vendored
Normal file
3
vfs/tests/mptest/testdata/mptest.wasm.bz2
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:816ddd36a255d2b3995f75924bdebd9dc7c4378eeccbe28176751434cfb788ca
|
||||
size 508404
|
||||
@@ -2,6 +2,7 @@ package speedtest1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/bzip2"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"flag"
|
||||
@@ -18,12 +19,13 @@ import (
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/speedtest1.wasm
|
||||
var binary []byte
|
||||
//go:embed testdata/speedtest1.wasm.bz2
|
||||
var compressed string
|
||||
|
||||
var (
|
||||
rt wazero.Runtime
|
||||
@@ -44,6 +46,11 @@ func TestMain(m *testing.M) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
binary, err := io.ReadAll(bzip2.NewReader(strings.NewReader(compressed)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
module, err = rt.CompileModule(ctx, binary)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -74,7 +81,7 @@ func initFlags() {
|
||||
|
||||
func Benchmark_speedtest1(b *testing.B) {
|
||||
output.Reset()
|
||||
ctx, vfs := vfs.NewContext(context.Background())
|
||||
ctx := util.NewContext(context.Background())
|
||||
name := filepath.Join(b.TempDir(), "test.db")
|
||||
args := append(options, "--size", strconv.Itoa(b.N), name)
|
||||
cfg := wazero.NewModuleConfig().
|
||||
@@ -88,5 +95,4 @@ func Benchmark_speedtest1(b *testing.B) {
|
||||
b.Error(err)
|
||||
}
|
||||
mod.Close(ctx)
|
||||
vfs.Close()
|
||||
}
|
||||
|
||||
2
vfs/tests/speedtest1/testdata/.gitattributes
vendored
2
vfs/tests/speedtest1/testdata/.gitattributes
vendored
@@ -1 +1 @@
|
||||
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
speedtest1.wasm.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
16
vfs/tests/speedtest1/testdata/build.sh
vendored
16
vfs/tests/speedtest1/testdata/build.sh
vendored
@@ -4,20 +4,24 @@ set -euo pipefail
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
ROOT=../../../../
|
||||
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
|
||||
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
|
||||
-o speedtest1.wasm main.c \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-mmutable-globals \
|
||||
-msimd128 -mmutable-globals \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined
|
||||
-Wl,--import-undefined \
|
||||
-D_HAVE_SQLITE_CONFIG_H
|
||||
|
||||
"$BINARYEN/wasm-opt" -g -O2 speedtest1.wasm -o speedtest1.tmp \
|
||||
--enable-multivalue --enable-mutable-globals \
|
||||
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
|
||||
speedtest1.wasm -o speedtest1.tmp \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
mv speedtest1.tmp speedtest1.wasm
|
||||
mv speedtest1.tmp speedtest1.wasm
|
||||
bzip2 -9f speedtest1.wasm
|
||||
2
vfs/tests/speedtest1/testdata/main.c
vendored
2
vfs/tests/speedtest1/testdata/main.c
vendored
@@ -1,8 +1,6 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
// Configuration
|
||||
#include "sqlite_cfg.h"
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
// VFS
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:23ac5877bc34cf5ff06ed2cf1c5ac148e5d364d6d01c85bd0f99d14f0a681830
|
||||
size 1513516
|
||||
3
vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2
vendored
Normal file
3
vfs/tests/speedtest1/testdata/speedtest1.wasm.bz2
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c0c551f14aeb338ab652f2bded33ba58188b76f2cb24ca076576c968341f841e
|
||||
size 523315
|
||||
83
vfs/vfs.go
83
vfs/vfs.go
@@ -44,33 +44,6 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
|
||||
return env
|
||||
}
|
||||
|
||||
type vfsKey struct{}
|
||||
type vfsState struct {
|
||||
files []File
|
||||
}
|
||||
|
||||
// NewContext is an internal API users need not call directly.
|
||||
//
|
||||
// NewContext creates a new context to hold [api.Module] specific VFS data.
|
||||
// The context should be passed to any [api.Function] calls that might
|
||||
// generate VFS host callbacks.
|
||||
// The returned [io.Closer] should be closed after the [api.Module] is closed,
|
||||
// to release any associated resources.
|
||||
func NewContext(ctx context.Context) (context.Context, io.Closer) {
|
||||
vfs := new(vfsState)
|
||||
return context.WithValue(ctx, vfsKey{}, vfs), vfs
|
||||
}
|
||||
|
||||
func (vfs *vfsState) Close() error {
|
||||
for _, f := range vfs.files {
|
||||
if f != nil {
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
vfs.files = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 {
|
||||
name := util.ReadString(mod, zVfsName, _MAX_STRING)
|
||||
if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) {
|
||||
@@ -183,6 +156,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
|
||||
file, flags, err = vfs.Open(path, flags)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return vfsErrorCode(err, _CANTOPEN)
|
||||
}
|
||||
|
||||
if file, ok := file.(FilePowersafeOverwrite); ok {
|
||||
if !parsed {
|
||||
params = vfsURIParameters(ctx, mod, zPath, flags)
|
||||
@@ -192,14 +169,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return vfsErrorCode(err, _CANTOPEN)
|
||||
}
|
||||
|
||||
vfsFileRegister(ctx, mod, pFile, file)
|
||||
if pOutFlags != 0 {
|
||||
util.WriteUint32(mod, pOutFlags, uint32(flags))
|
||||
}
|
||||
vfsFileRegister(ctx, mod, pFile, file)
|
||||
return _OK
|
||||
}
|
||||
|
||||
@@ -291,14 +264,6 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
|
||||
return _OK
|
||||
}
|
||||
|
||||
case _FCNTL_LOCK_TIMEOUT:
|
||||
if file, ok := file.(*vfsFile); ok {
|
||||
millis := file.lockTimeout.Milliseconds()
|
||||
file.lockTimeout = time.Duration(util.ReadUint32(mod, pArg)) * time.Millisecond
|
||||
util.WriteUint32(mod, pArg, uint32(millis))
|
||||
return _OK
|
||||
}
|
||||
|
||||
case _FCNTL_POWERSAFE_OVERWRITE:
|
||||
if file, ok := file.(FilePowersafeOverwrite); ok {
|
||||
switch util.ReadUint32(mod, pArg) {
|
||||
@@ -336,6 +301,12 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
|
||||
return vfsErrorCode(err, _IOERR_FSTAT)
|
||||
}
|
||||
|
||||
case _FCNTL_OVERWRITE:
|
||||
if file, ok := file.(FileOverwrite); ok {
|
||||
err := file.Overwrite()
|
||||
return vfsErrorCode(err, _IOERR)
|
||||
}
|
||||
|
||||
case _FCNTL_COMMIT_PHASETWO:
|
||||
if file, ok := file.(FileCommitPhaseTwo); ok {
|
||||
err := file.CommitPhaseTwo()
|
||||
@@ -360,10 +331,8 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
|
||||
}
|
||||
|
||||
// Consider also implementing these opcodes (in use by SQLite):
|
||||
// _FCNTL_PDB
|
||||
// _FCNTL_BUSYHANDLER
|
||||
// _FCNTL_CHUNK_SIZE
|
||||
// _FCNTL_OVERWRITE
|
||||
// _FCNTL_PRAGMA
|
||||
// _FCNTL_SYNC
|
||||
return _NOTFOUND
|
||||
@@ -431,40 +400,22 @@ func vfsGet(mod api.Module, pVfs uint32) VFS {
|
||||
panic(util.NoVFSErr + util.ErrorString(name))
|
||||
}
|
||||
|
||||
func vfsFileNew(vfs *vfsState, file File) uint32 {
|
||||
// Find an empty slot.
|
||||
for id, f := range vfs.files {
|
||||
if f == nil {
|
||||
vfs.files[id] = file
|
||||
return uint32(id)
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new slot.
|
||||
vfs.files = append(vfs.files, file)
|
||||
return uint32(len(vfs.files) - 1)
|
||||
}
|
||||
|
||||
func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) {
|
||||
const fileHandleOffset = 4
|
||||
id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file)
|
||||
id := util.AddHandle(ctx, file)
|
||||
util.WriteUint32(mod, pFile+fileHandleOffset, id)
|
||||
}
|
||||
|
||||
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
|
||||
const fileHandleOffset = 4
|
||||
vfs := ctx.Value(vfsKey{}).(*vfsState)
|
||||
id := util.ReadUint32(mod, pFile+fileHandleOffset)
|
||||
return vfs.files[id]
|
||||
return util.GetHandle(ctx, id).(File)
|
||||
}
|
||||
|
||||
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {
|
||||
const fileHandleOffset = 4
|
||||
vfs := ctx.Value(vfsKey{}).(*vfsState)
|
||||
id := util.ReadUint32(mod, pFile+fileHandleOffset)
|
||||
file := vfs.files[id]
|
||||
vfs.files[id] = nil
|
||||
return file.Close()
|
||||
return util.DelHandle(ctx, id)
|
||||
}
|
||||
|
||||
func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
|
||||
@@ -477,9 +428,3 @@ func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func clear(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,8 +220,7 @@ func Test_vfsAccess(t *testing.T) {
|
||||
|
||||
func Test_vfsFile(t *testing.T) {
|
||||
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
|
||||
ctx, vfs := NewContext(context.TODO())
|
||||
defer vfs.Close()
|
||||
ctx := util.NewContext(context.TODO())
|
||||
|
||||
// Open a temporary file.
|
||||
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
|
||||
@@ -293,8 +292,7 @@ func Test_vfsFile(t *testing.T) {
|
||||
|
||||
func Test_vfsFile_psow(t *testing.T) {
|
||||
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
|
||||
ctx, vfs := NewContext(context.TODO())
|
||||
defer vfs.Close()
|
||||
ctx := util.NewContext(context.TODO())
|
||||
|
||||
// Open a temporary file.
|
||||
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
|
||||
|
||||
Reference in New Issue
Block a user