Compare commits

...

60 Commits

Author SHA1 Message Date
Nuno Cruces
994d9b1812 Updated dependencies. 2023-10-02 10:09:26 +01:00
Nuno Cruces
b19bd28ed3 Simplify lock timeouts. 2023-10-02 10:06:09 +01:00
Nuno Cruces
e66bd51845 More VFS API. 2023-09-21 02:43:45 +01:00
Nuno Cruces
f5614bc2ed Tweaks. 2023-09-20 15:07:07 +01:00
Nuno Cruces
d9fcf60b7d Driver API. 2023-09-20 02:41:09 +01:00
Nuno Cruces
ac6dd1aa5f Updated dependencies. 2023-09-18 15:22:11 +01:00
Nuno Cruces
b1495bd6cb Build tags, docs. 2023-09-18 15:11:05 +01:00
Nuno Cruces
2d91760295 Portability. 2023-09-18 12:44:18 +01:00
Nuno Cruces
38d4254bc4 Update README.md 2023-09-15 15:37:57 +01:00
Nuno Cruces
c0aa734786 binaryen-version_116. 2023-09-15 15:10:08 +01:00
Nuno Cruces
fa845dbd3d Run test in all platforms. 2023-09-12 15:30:43 +01:00
Nuno Cruces
fed315ab79 Update go.yml 2023-09-12 15:28:11 +01:00
Nuno Cruces
726d7316f7 Update README.md 2023-09-12 00:00:32 +01:00
Nuno Cruces
ddb387b021 Updated dependencies. 2023-09-11 23:54:22 +01:00
Nuno Cruces
d0f19507f5 SQLite 3.43.1. 2023-09-11 23:48:38 +01:00
Nuno Cruces
9d997552ad Pearson correlation. 2023-09-02 00:48:55 +01:00
Nuno Cruces
9d75c39dcc Update README.md 2023-09-01 16:01:42 +01:00
Nuno Cruces
746a84965e Covariance. 2023-09-01 02:38:57 +01:00
Nuno Cruces
312d3b58f2 Statistics functions. 2023-09-01 01:23:25 +01:00
Nuno Cruces
b71cd295c2 Updated dependencies. 2023-08-25 09:56:09 +01:00
Nuno Cruces
5b3b61a304 SQLite 3.43.0. 2023-08-24 18:56:23 +01:00
Nuno Cruces
d661d15723 wazero v1.5.0. 2023-08-24 18:56:10 +01:00
Nuno Cruces
1e38165ad0 Timer resolution. 2023-08-20 03:12:55 +01:00
Nuno Cruces
58a32d7c9d Update GORM. 2023-08-20 00:56:08 +01:00
Nuno Cruces
6765e883c1 Register collation. 2023-08-10 13:39:52 +01:00
Nuno Cruces
18fc608433 Embed database as string. 2023-08-10 13:23:54 +01:00
Nuno Cruces
77f37893b9 Driver connector. 2023-08-10 13:18:13 +01:00
Nuno Cruces
f1e36e2581 Updated dependencies. 2023-08-09 16:30:32 +01:00
Nuno Cruces
772b9153c7 Use clear builtin. 2023-08-09 16:16:45 +01:00
Nuno Cruces
4b280a3a7e Updated dependencies. 2023-08-09 15:22:48 +01:00
Nuno Cruces
19b6098bf6 Update go.yml (#28) 2023-08-05 01:12:16 +01:00
dependabot[bot]
2aa685320f Bump golang.org/x/text from 0.11.0 to 0.12.0 (#26)
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.11.0 to 0.12.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.11.0...v0.12.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-05 00:36:56 +01:00
dependabot[bot]
9941be05c2 Bump golang.org/x/sys from 0.10.0 to 0.11.0 (#27)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.10.0 to 0.11.0.
- [Commits](https://github.com/golang/sys/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-05 00:35:11 +01:00
Nuno Cruces
a0a9ab7737 Avoid unnecessary alloc. 2023-08-04 14:12:36 +01:00
Nuno Cruces
a77727a1ce Port script. 2023-07-31 15:27:10 +01:00
Nuno Cruces
47fe032078 Updated dependencies. 2023-07-26 12:42:18 +01:00
Nuno Cruces
bdfe279444 Soundex. 2023-07-26 02:02:39 +01:00
dependabot[bot]
a86937a54e Bump github.com/tetratelabs/wazero from 1.3.0 to 1.3.1
Bumps [github.com/tetratelabs/wazero](https://github.com/tetratelabs/wazero) from 1.3.0 to 1.3.1.
- [Release notes](https://github.com/tetratelabs/wazero/releases)
- [Commits](https://github.com/tetratelabs/wazero/compare/v1.3.0...v1.3.1)

---
updated-dependencies:
- dependency-name: github.com/tetratelabs/wazero
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-25 08:02:20 +01:00
Nuno Cruces
6ef422fbde Unicode tests. 2023-07-13 12:19:32 +01:00
Nuno Cruces
ff0cb6fb88 Unicode tests, fixes. 2023-07-12 13:39:07 +01:00
Nuno Cruces
72db90efdf Unicode. 2023-07-11 16:34:15 +01:00
Nuno Cruces
5a3fdef3c5 wazero v1.3.0. 2023-07-11 12:30:39 +01:00
dependabot[bot]
ff34b0cae1 Bump golang.org/x/text from 0.10.0 to 0.11.0
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-04 23:55:17 +01:00
Nuno Cruces
f064492bb1 Updated dependencies. 2023-07-04 19:55:11 +01:00
Nuno Cruces
1427d30541 Updated dependencies. 2023-07-04 19:48:55 +01:00
Nuno Cruces
d3730341f0 Unknown collations. 2023-07-04 11:16:29 +01:00
Nuno Cruces
78ac2386f6 Refactor. 2023-07-04 02:29:38 +01:00
Nuno Cruces
632ea933b3 Function aux data. 2023-07-04 02:18:03 +01:00
Nuno Cruces
0f7fa6ebc9 Tests. 2023-07-03 18:28:46 +01:00
Nuno Cruces
6f7f776488 Refactor. 2023-07-03 17:42:53 +01:00
Nuno Cruces
f6d7c5e9c5 Refactor. 2023-07-03 17:08:16 +01:00
Nuno Cruces
1cc7ecfe8d Custom aggregate functions. 2023-07-03 15:45:16 +01:00
Nuno Cruces
3844e81404 Custom aggregate functions. 2023-07-01 15:19:45 +01:00
Nuno Cruces
fec1f8d32a Custom scalar functions. 2023-07-01 00:16:42 +01:00
Nuno Cruces
31572e6095 Fix nil/zero handles. 2023-06-30 17:09:01 +01:00
Nuno Cruces
4aee38b957 Error handling. 2023-06-30 12:25:07 +01:00
Nuno Cruces
232a7705b5 Wrap context. 2023-06-30 11:48:54 +01:00
Nuno Cruces
a6c2fccd74 Wrap value. 2023-06-30 10:45:16 +01:00
Nuno Cruces
6a982559cd Custom collating sequences. 2023-06-30 02:49:21 +01:00
Nuno Cruces
c7904d30de Refactor file handles. 2023-06-30 01:52:18 +01:00
73 changed files with 2641 additions and 587 deletions

View File

@@ -34,9 +34,8 @@ jobs:
- name: Download
run: go mod download
# Fixed in go 1.21: https://go.dev/issue/54372
# - name: Verify
# run: go mod verify
- name: Verify
run: go mod verify
- name: Vet
run: go vet ./...
@@ -48,8 +47,12 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test no locks
run: go test -v -tags sqlite3_nolock .
if: matrix.os == 'ubuntu-latest'
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
run: go test -v -tags sqlite3_flock ./...
if: matrix.os == 'macos-latest'
- name: Coverage report
@@ -57,7 +60,8 @@ jobs:
with:
chart: 'true'
amend: 'true'
reuse-go: 'true'
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'
github.event_name == 'push' &&
matrix.os == 'ubuntu-latest'
continue-on-error: true

View File

@@ -7,18 +7,26 @@
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
- Package [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- Package [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
- [`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
wraps the [C SQLite API](https://www.sqlite.org/cintro.html)
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3#example-package)).
- [`github.com/ncruces/go-sqlite3/driver`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver)
provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- [`github.com/ncruces/go-sqlite3/vfs/memdb`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb)
implements an in-memory VFS.
- [`github.com/ncruces/go-sqlite3/vfs/readervfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs)
implements a VFS for immutable databases.
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
registers Unicode aware functions.
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
registers [statistics functions](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html).
- [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Caveats
@@ -35,26 +43,37 @@ To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
to always use `EXCLUSIVE` locking mode for WAL databases.
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
to open WAL databases you should disable connection pooling by calling
to use the [`database/sql`](https://pkg.go.dev/database/sql)
driver with WAL mode databases you should disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### POSIX Advisory Locks
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
[broken by design](https://www.sqlite.org/src/artifact/2e8b12?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
On BSD Unixes, this module uses
On BSD Unixes, this module may use
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
##### TL;DR
In all platforms for which this package builds,
it should be safe to use it to access databases concurrently,
from multiple goroutines, processes, and
with _other_ implementations of SQLite.
If the package does not build for your platform,
see [this](vfs/README.md#portability).
#### Testing
The pure Go VFS is tested by running an unmodified build of SQLite's
The pure Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
@@ -63,6 +82,7 @@ Performance is tested by running
### Roadmap
- [ ] advanced SQLite features
- [x] custom functions
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
@@ -72,11 +92,10 @@ Performance is tested by running
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom SQL functions
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)
- [`crawshaw.io/sqlite`](https://pkg.go.dev/crawshaw.io/sqlite)
- [`github.com/mattn/go-sqlite3`](https://pkg.go.dev/github.com/mattn/go-sqlite3)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)
- [`github.com/zombiezen/go-sqlite`](https://pkg.go.dev/github.com/zombiezen/go-sqlite)

View File

@@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r, dst)
return nil, c.sqlite.error(r, dst)
}
return &Backup{

36
conn.go
View File

@@ -2,7 +2,6 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"net/url"
@@ -19,7 +18,7 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
*module
*sqlite
interrupt context.Context
waiter chan struct{}
@@ -50,19 +49,19 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
mod.close()
sqlite.close()
} else {
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
c := &Conn{module: mod}
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
@@ -80,7 +79,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -99,7 +98,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r, handle, pragmas.String()); err != nil {
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
@@ -113,7 +112,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
panic(err)
}
}
@@ -143,7 +142,7 @@ func (c *Conn) Close() error {
c.handle = 0
runtime.SetFinalizer(c, nil)
return c.module.close()
return c.close()
}
// Exec is a convenience function that allows an application to run
@@ -240,6 +239,11 @@ func (c *Conn) Changes() int64 {
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is it the same context?
if ctx == c.interrupt {
return ctx
}
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
@@ -319,7 +323,7 @@ func (c *Conn) Pragma(str string) ([]string, error) {
}
func (c *Conn) error(rc uint64, sql ...string) error {
return c.module.error(rc, c.handle, sql...)
return c.sqlite.error(rc, c.handle, sql...)
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
@@ -331,15 +335,5 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
SetInterrupt(ctx context.Context) (old context.Context)
Savepoint() Savepoint
Backup(srcDB, dstURI string) error
Restore(dstDB, srcURI string) error
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
Raw() *Conn
}

View File

@@ -167,6 +167,18 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
)
// Datatype is a fundamental datatype of SQLite.
//
// https://www.sqlite.org/c3ref/c_blob.html
@@ -182,18 +194,18 @@ const (
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
const name = "INTEGERFLOATEXTBLOBNULL"
switch t {
case INTEGER:
return name[0:7]
case FLOAT:
return name[7:12]
case TEXT:
return name[12:16]
return name[11:15]
case BLOB:
return name[16:20]
return name[15:19]
case NULL:
return name[20:24]
return name[19:23]
}
return strconv.FormatUint(uint64(t), 10)
}

174
context.go Normal file
View File

@@ -0,0 +1,174 @@
package sqlite3
import (
"errors"
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Context is the context in which an SQL function executes.
// An SQLite [Context] is in no way related to a Go [context.Context].
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
*sqlite
handle uint32
}
// SetAuxData saves metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(c.ctx, data)
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) GetAuxData(n int) any {
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
return util.GetHandle(c.ctx, ptr)
}
// ResultBool sets the result of the function to a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBool(value bool) {
var i int64
if value {
i = 1
}
c.ResultInt64(i)
}
// ResultInt sets the result of the function to an int.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt(value int) {
c.ResultInt64(int64(value))
}
// ResultInt64 sets the result of the function to an int64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt64(value int64) {
c.call(c.api.resultInteger,
uint64(c.handle), uint64(value))
}
// ResultFloat sets the result of the function to a float64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultFloat(value float64) {
c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value))
}
// ResultText sets the result of the function to a string.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultText(value string) {
ptr := c.newString(value)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBlob(value []byte) {
ptr := c.newBytes(value)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor))
}
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultZeroBlob(n int64) {
c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n))
}
// ResultNull sets the result of the function to NULL.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultNull() {
c.call(c.api.resultNull,
uint64(c.handle))
}
// ResultTime sets the result of the function to a [time.Time].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultTime(value time.Time, format TimeFormat) {
if format == TimeFormatDefault {
c.resultRFC3339Nano(value)
return
}
switch v := format.Encode(value).(type) {
case string:
c.ResultText(v)
case int64:
c.ResultInt64(v)
case float64:
c.ResultFloat(v)
default:
panic(util.AssertErr())
}
}
func (c Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := c.new(maxlen)
buf := util.View(c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.api.destructor), _UTF8)
}
// ResultError sets the result of the function an error.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
c.call(c.api.resultErrorMem, uint64(c.handle))
return
}
if errors.Is(err, TOOBIG) {
c.call(c.api.resultErrorBig, uint64(c.handle))
return
}
str := err.Error()
ptr := c.newString(str)
c.call(c.api.resultError,
uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
var code uint64
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
code = uint64(xcode)
case errors.As(err, &ecode):
code = uint64(ecode)
}
if code != 0 {
c.call(c.api.resultErrorCode,
uint64(c.handle), code)
}
}

View File

@@ -40,42 +40,96 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
var driverName = "sqlite3"
func init() {
sql.Register("sqlite3", sqlite{})
if driverName != "" {
sql.Register(driverName, sqlite{})
}
}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to database/sql.
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
var c conn
c.Conn, err = sqlite3.Open(name)
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := newConnector(name, nil)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
var pragmas bool
c.txBegin = "BEGIN"
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
return newConnector(name, nil)
}
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
query, err := url.ParseQuery(after)
if err != nil {
return nil, err
}
pragmas = len(query["_pragma"]) > 0
c.txlock = query.Get("_txlock")
c.pragmas = len(query["_pragma"]) > 0
}
}
if !pragmas {
err := c.Conn.Exec(`PRAGMA busy_timeout=60000`)
return &c, nil
}
type connector struct {
init func(ctx context.Context, conn *sqlite3.Conn) error
name string
txlock string
pragmas bool
}
func (n *connector) Driver() driver.Driver {
return sqlite{}
}
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
var c conn
c.Conn, err = sqlite3.Open(n.name)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
c.Close()
}
}()
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
switch n.txlock {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + n.txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
}
if !n.pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
if err != nil {
return nil, err
}
c.reusable = true
@@ -86,7 +140,6 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
PRAGMA_query_only;
`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
@@ -95,7 +148,12 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}
if n.init != nil {
err = n.init(ctx, c.Conn)
if err != nil {
return nil, err
}
}
@@ -113,12 +171,17 @@ type conn struct {
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
_ driver.ConnPrepareContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
func (c *conn) IsValid() bool {
return c.reusable
}
@@ -163,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
if err != nil && !c.Conn.GetAutocommit() {
c.Rollback()
}
return err
@@ -255,13 +318,14 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
}
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
// Use QueryContext to setup bindings.
// No need to close rows: that simply resets the statement, exec does the same.
_, err := s.QueryContext(ctx, args)
err := s.setupBindings(args)
if err != nil {
return nil, err
}
old := s.Conn.SetInterrupt(ctx)
defer s.Conn.SetInterrupt(old)
err = s.Stmt.Exec()
if err != nil {
return nil, err
@@ -271,10 +335,18 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.Stmt.ClearBindings()
err := s.setupBindings(args)
if err != nil {
return nil, err
}
return &rows{ctx, s.Stmt, s.Conn}, nil
}
func (s *stmt) setupBindings(args []driver.NamedValue) error {
err := s.Stmt.ClearBindings()
if err != nil {
return err
}
var ids [3]int
for _, arg := range args {
@@ -314,11 +386,10 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
}
}
if err != nil {
return nil, err
return err
}
}
return &rows{ctx, s.Stmt, s.Conn}, nil
return nil
}
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {

View File

@@ -47,7 +47,7 @@ func ExampleDriverConn() {
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
conn := driverConn.(sqlite3.DriverConn).Raw()
savept := conn.Savepoint()
defer savept.Release(&err)

View File

@@ -1,6 +1,6 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
This folder includes an embeddable WASM build of SQLite 3.43.1 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
@@ -9,6 +9,7 @@ The following optional features are compiled in:
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [soundex](https://www.sqlite.org/lang_corefunc.html#soundex)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

View File

@@ -33,10 +33,10 @@ sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_reopen
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
@@ -48,12 +48,14 @@ sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit
sqlite3_anycollseq_init
sqlite3_create_go_collation
sqlite3_create_go_function
sqlite3_create_go_window_function
sqlite3_create_go_aggregate_function
sqlite3_create_collation_go
sqlite3_create_function_go
sqlite3_create_aggregate_function_go
sqlite3_create_window_function_go
sqlite3_aggregate_context
sqlite3_user_data
sqlite3_set_auxdata_go
sqlite3_get_auxdata
sqlite3_value_type
sqlite3_value_int64
sqlite3_value_double

Binary file not shown.

View File

@@ -68,6 +68,19 @@ func (e *Error) Is(err error) bool {
return false
}
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
func (e *Error) As(err any) bool {
switch c := err.(type) {
case *ErrorCode:
*c = e.Code()
return true
case *ExtendedErrorCode:
*c = e.ExtendedCode()
return true
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
@@ -104,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool {
return ok && c == ErrorCode(e)
}
// As converts this error to an [ErrorCode].
func (e ExtendedErrorCode) As(err any) bool {
c, ok := err.(*ErrorCode)
if ok {
*c = ErrorCode(e)
}
return ok
}
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY

View File

@@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) {
func TestError(t *testing.T) {
t.Parallel()
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
var ecode ErrorCode
var xcode xErrorCode
err := &Error{code: 0x8080}
if !errors.As(err, &err) {
t.Fatal("want true")
}
if !errors.Is(&err, ErrorCode(0x80)) {
if ecode := err.Code(); ecode != 0x80 {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if !errors.Is(err, ErrorCode(0x80)) {
t.Errorf("want true")
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
if xcode := err.ExtendedCode(); xcode != 0x8080 {
t.Errorf("got %#x, want 0x8080", uint16(xcode))
}
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
t.Errorf("got %#x, want 0x8080", uint16(xcode))
}
if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
t.Errorf("want true")
}

109
ext/stats/stats.go Normal file
View File

@@ -0,0 +1,109 @@
// Package stats provides aggregate functions for statistics.
//
// Functions:
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - var_pop: population variance
// - var_samp: sample variance
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
//
// See: [ANSI SQL Aggregate Functions]
//
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import "github.com/ncruces/go-sqlite3"
// Register registers statistics functions.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
}
const (
var_pop = iota
var_samp
stddev_pop
stddev_samp
corr
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
}
type variance struct {
kind int
welford
}
func (fn *variance) Value(ctx sqlite3.Context) {
var r float64
switch fn.kind {
case var_pop:
r = fn.var_pop()
case var_samp:
r = fn.var_samp()
case stddev_pop:
r = fn.stddev_pop()
case stddev_samp:
r = fn.stddev_samp()
}
ctx.ResultFloat(r)
}
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
fn.enqueue(a.Float())
}
}
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.Type() != sqlite3.NULL {
fn.dequeue(a.Float())
}
}
func newCovariance(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
}
type covariance struct {
kind int
welford2
}
func (fn *covariance) Value(ctx sqlite3.Context) {
var r float64
switch fn.kind {
case var_pop:
r = fn.covar_pop()
case var_samp:
r = fn.covar_samp()
case corr:
r = fn.correlation()
}
ctx.ResultFloat(r)
}
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
fn.enqueue(a.Float(), b.Float())
}
}
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a, b := arg[0], arg[1]
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
fn.dequeue(a.Float(), b.Float())
}
}

140
ext/stats/stats_test.go Normal file
View File

@@ -0,0 +1,140 @@
package stats
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestRegister_variance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
stddev_samp(x), stddev_pop(x)
FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
}
if got := stmt.ColumnFloat(1); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(2); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := stmt.ColumnFloat(3); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}
func TestRegister_covariance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}

109
ext/stats/welford.go Normal file
View File

@@ -0,0 +1,109 @@
package stats
import "math"
// Welford's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type welford struct {
m1, m2 kahan
n uint64
}
func (w welford) average() float64 {
return w.m1.hi
}
func (w welford) var_pop() float64 {
return w.m2.hi / float64(w.n)
}
func (w welford) var_samp() float64 {
return w.m2.hi / float64(w.n-1) // Bessel's correction
}
func (w welford) stddev_pop() float64 {
return math.Sqrt(w.var_pop())
}
func (w welford) stddev_samp() float64 {
return math.Sqrt(w.var_samp())
}
func (w *welford) enqueue(x float64) {
w.n++
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}
func (w *welford) dequeue(x float64) {
w.n--
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(w.n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
type welford2 struct {
m1x, m2x kahan
m1y, m2y kahan
cov kahan
n uint64
}
func (w welford2) covar_pop() float64 {
return w.cov.hi / float64(w.n)
}
func (w welford2) covar_samp() float64 {
return w.cov.hi / float64(w.n-1) // Bessel's correction
}
func (w welford2) correlation() float64 {
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
}
func (w *welford2) enqueue(x, y float64) {
w.n++
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.add(d1x / float64(w.n))
w.m1y.add(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.add(d1x * d2x)
w.m2y.add(d1y * d2y)
w.cov.add(d1x * d2y)
}
func (w *welford2) dequeue(x, y float64) {
w.n--
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.sub(d1x / float64(w.n))
w.m1y.sub(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.sub(d1x * d2x)
w.m2y.sub(d1y * d2y)
w.cov.sub(d1x * d2y)
}
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

75
ext/stats/welford_test.go Normal file
View File

@@ -0,0 +1,75 @@
package stats
import (
"math"
"testing"
)
func Test_welford(t *testing.T) {
var s1, s2 welford
s1.enqueue(4)
s1.enqueue(7)
s1.enqueue(13)
s1.enqueue(16)
if got := s1.average(); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := s1.var_samp(); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := s1.var_pop(); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := s1.stddev_samp(); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
s1.dequeue(4)
s2.enqueue(7)
s2.enqueue(13)
s2.enqueue(16)
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
}
func Test_covar(t *testing.T) {
var c1, c2 welford2
c1.enqueue(3, 70)
c1.enqueue(5, 80)
c1.enqueue(2, 60)
c1.enqueue(7, 90)
c1.enqueue(4, 75)
if got := c1.covar_samp(); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := c1.covar_pop(); got != 17 {
t.Errorf("got %v, want 17", got)
}
c1.dequeue(3, 70)
c2.enqueue(5, 80)
c2.enqueue(2, 60)
c2.enqueue(7, 90)
c2.enqueue(4, 75)
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
}
func Test_correlation(t *testing.T) {
var c welford2
c.enqueue(1, 3)
c.enqueue(2, 2)
c.enqueue(3, 1)
if got := c.correlation(); got != -1 {
t.Errorf("got %v, want -1", got)
}
}

181
ext/unicode/unicode.go Normal file
View File

@@ -0,0 +1,181 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions,
// - LIKE and REGEXP operators,
// - collation sequences.
//
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regex/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
package unicode
import (
"bytes"
"regexp"
"strings"
"unicode/utf8"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"golang.org/x/text/cases"
"golang.org/x/text/collate"
"golang.org/x/text/language"
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("like", 2, flags, like)
db.CreateFunction("like", 3, flags, like)
db.CreateFunction("upper", 1, flags, upper)
db.CreateFunction("upper", 2, flags, upper)
db.CreateFunction("lower", 1, flags, lower)
db.CreateFunction("lower", 2, flags, lower)
db.CreateFunction("regexp", 2, flags, regex)
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
}
})
}
// RegisterCollation registers a Unicode collation sequence for a database connection.
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
tag, err := language.Parse(locale)
if err != nil {
return err
}
return db.CreateCollation(name, collate.New(tag).Compare)
}
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return
}
c := cases.Upper(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
}
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return
}
c := cases.Lower(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
re = r
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
}
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
escape := rune(-1)
if len(arg) == 3 {
var size int
b := arg[2].RawBlob()
escape, size = utf8.DecodeRune(b)
if size != len(b) {
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
return
}
}
type likeData struct {
*regexp.Regexp
escape rune
}
re, ok := ctx.GetAuxData(0).(likeData)
if !ok || re.escape != escape {
re = likeData{
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
escape,
}
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
}
func like2regex(pattern string, escape rune) string {
var re strings.Builder
start := 0
literal := false
re.Grow(len(pattern) + 10)
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
for i, r := range pattern {
if start < 0 {
start = i
}
if literal {
literal = false
continue
}
var symbol string
switch r {
case '_':
symbol = `.`
case '%':
symbol = `.*`
case escape:
literal = true
default:
continue
}
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
re.WriteString(symbol)
start = -1
}
if start >= 0 {
re.WriteString(regexp.QuoteMeta(pattern[start:]))
}
re.WriteString(`\z`)
return re.String()
}

215
ext/unicode/unicode_test.go Normal file
View File

@@ -0,0 +1,215 @@
package unicode
import (
"errors"
"reflect"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
exec := func(fn string) string {
stmt, _, err := db.Prepare(`SELECT ` + fn)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnText(0)
}
t.Fatal(stmt.Err())
return ""
}
Register(db)
tests := []struct {
test string
want string
}{
{`upper('hello')`, "HELLO"},
{`lower('HELLO')`, "hello"},
{`upper('привет')`, "ПРИВЕТ"},
{`lower('ПРИВЕТ')`, "привет"},
{`upper('istanbul')`, "ISTANBUL"},
{`upper('istanbul', 'tr-TR')`, "İSTANBUL"},
{`lower('Dünyanın İlk Borsası', 'tr-TR')`, "dünyanın ilk borsası"},
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
{`'Hello' REGEXP 'ell'`, "1"},
{`'Hello' REGEXP 'el.'`, "1"},
{`'Hello' LIKE 'hel_'`, "0"},
{`'Hello' LIKE 'hel%'`, "1"},
{`'Hello' LIKE 'h_llo'`, "1"},
{`'Hello' LIKE 'hello'`, "1"},
{`'Привет' LIKE 'ПРИВЕТ'`, "1"},
{`'100%' LIKE '100|%' ESCAPE '|'`, "1"},
}
for _, tt := range tests {
t.Run(tt.test, func(t *testing.T) {
if got := exec(tt.test); got != tt.want {
t.Errorf("exec(%q) = %q, want %q", tt.test, got, tt.want)
}
})
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRegister_collation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"}
for stmt.Step() {
got = append(got, stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Error("not equal")
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRegister_error(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`SELECT upper('hello', 'enUS')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT lower('hello', 'enUS')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' LIKE 'HELLO' ESCAPE '\\'`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT icu_load_collation('enUS', 'error')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT icu_load_collation('enUS', '')`)
if err != nil {
t.Error(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func Test_like2regex(t *testing.T) {
const prefix = `(?is)\A`
const sufix = `\z`
tests := []struct {
pattern string
escape rune
want string
}{
{`a`, -1, `a`},
{`a.`, -1, `a\.`},
{`a%`, -1, `a.*`},
{`a\`, -1, `a\\`},
{`a_b`, -1, `a.b`},
{`a|b`, '|', `ab`},
{`a|_`, '|', `a_`},
}
for _, tt := range tests {
t.Run(tt.pattern, func(t *testing.T) {
want := prefix + tt.want + sufix
if got := like2regex(tt.pattern, tt.escape); got != want {
t.Errorf("like2regex() = %q, want %q", got, want)
}
})
}
}

182
func.go
View File

@@ -8,29 +8,179 @@ import (
"github.com/tetratelabs/wazero/api"
)
// AnyCollationNeeded registers a fake collating function
// for any unknown collating sequence.
// The fake collating function works like BINARY.
//
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
}
// CreateCollation defines a new collating sequence.
//
// https://www.sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createCollation,
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
if err := c.error(r); err != nil {
util.DelHandle(c.ctx, funcPtr)
return err
}
return nil
}
// CreateFunction defines a new scalar SQL function.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
call := c.api.createAggregate
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
if _, ok := fn().(WindowFunction); ok {
call = c.api.createWindow
}
r := c.call(call,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// AggregateFunction is the interface an aggregate function should implement.
//
// https://www.sqlite.org/appfunc.html
type AggregateFunction interface {
// Step is invoked to add a row to the current window.
// The function arguments, if any, corresponding to the row being added are passed to Step.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current value of the aggregate.
Value(ctx Context)
}
// WindowFunction is the interface an aggregate window function should implement.
//
// https://www.sqlite.org/windowfunctions.html
type WindowFunction interface {
AggregateFunction
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
// The function arguments, if any, are those passed to Step for the row being removed.
Inverse(ctx Context, arg ...Value)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", cbDestroy)
util.ExportFuncIIIIII(env, "go_compare", cbCompare)
util.ExportFuncVIII(env, "go_func", cbFunc)
util.ExportFuncVIII(env, "go_step", cbStep)
util.ExportFuncVI(env, "go_final", cbFinal)
util.ExportFuncVI(env, "go_value", cbValue)
util.ExportFuncVIII(env, "go_inverse", cbInverse)
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}
func cbDestroy(ctx context.Context, mod api.Module, pArg uint32) {}
func cbCompare(ctx context.Context, mod api.Module, pArg, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
return 0
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func cbFinal(ctx context.Context, mod api.Module, pCtx uint32) {}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func cbValue(ctx context.Context, mod api.Module, pCtx uint32) {}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{sqlite, pCtx}.ResultError(err)
}
}
func cbInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
return util.GetHandle(sqlite.ctx, pApp)
}
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
// Try loading the handle, if we already have one, or want a new one.
if ptr != 0 || size != 0 {
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
fn := util.GetHandle(sqlite.ctx, handle)
if close != nil {
*close = handle
}
if fn != nil {
return fn
}
}
}
// Create a new aggregate and store the handle.
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
}
return fn
}
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
sqlite: sqlite,
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
}
}
return args
}

154
func_test.go Normal file
View File

@@ -0,0 +1,154 @@
package sqlite3_test
import (
"bytes"
"fmt"
"log"
"regexp"
"golang.org/x/text/collate"
"golang.org/x/text/language"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateCollation() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateCollation("french", collate.New(language.French).Compare)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// cote
// coté
// côte
// côté
// cotée
// coter
}
func ExampleConn_CreateFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// COTE
// COTÉ
// CÔTE
// CÔTÉ
// COTÉE
// COTER
}
func ExampleContext_SetAuxData() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
ctx.SetAuxData(0, r)
re = r
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// cote
// côte
// cotée
}

87
func_win_test.go Normal file
View File

@@ -0,0 +1,87 @@
package sqlite3_test
import (
"fmt"
"log"
"unicode"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateWindowFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 1
// 2
// 2
// 1
// 0
// 0
}
type countASCII struct{ result int }
func newASCIICounter() sqlite3.AggregateFunction {
return &countASCII{}
}
func (f *countASCII) Value(ctx sqlite3.Context) {
ctx.ResultInt(f.result)
}
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if f.isASCII(arg[0]) {
f.result++
}
}
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if f.isASCII(arg[0]) {
f.result--
}
}
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
if arg.Type() != sqlite3.TEXT {
return false
}
for _, c := range arg.RawBlob() {
if c > unicode.MaxASCII {
return false
}
}
return true
}

7
go.mod
View File

@@ -1,13 +1,14 @@
module github.com/ncruces/go-sqlite3
go 1.19
go 1.21
require (
github.com/ncruces/julianday v0.1.5
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.2.1
github.com/tetratelabs/wazero v1.5.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.9.0
golang.org/x/sys v0.12.0
golang.org/x/text v0.13.0
)
retract v0.4.0 // tagged from the wrong branch

10
go.sum
View File

@@ -2,9 +2,11 @@ github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FB
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=

View File

@@ -1,4 +1,4 @@
go 1.19
go 1.21
use (
.

5
go.work.sum Normal file
View File

@@ -0,0 +1,5 @@
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=

View File

@@ -1,16 +1,16 @@
module github.com/ncruces/go-sqlite3/gormlite
go 1.19
go 1.21
require (
github.com/ncruces/go-sqlite3 v0.8.1
gorm.io/gorm v1.25.2
github.com/ncruces/go-sqlite3 v0.9.0
gorm.io/gorm v1.25.4
)
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.2.1 // indirect
golang.org/x/sys v0.9.0 // indirect
github.com/tetratelabs/wazero v1.5.0 // indirect
golang.org/x/sys v0.12.0 // indirect
)

View File

@@ -2,13 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.8.1 h1:e1Y7uHu96xC4fWKsCVWprbTi8vAaQX9R+8kgkxOHWaY=
github.com/ncruces/go-sqlite3 v0.8.1/go.mod h1:EhHe1qvG6Zc/8ffYMzre8n//rTRs1YNN5dUD1f1mEGc=
github.com/ncruces/go-sqlite3 v0.9.0 h1:tl5eEmGEyzZH2ur8sDgPJTdzV4CRnKpsFngoP1QRjD8=
github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

View File

@@ -3,9 +3,7 @@ package gormlite
import (
"context"
"database/sql"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
@@ -14,7 +12,7 @@ import (
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
@@ -34,7 +32,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
conn, err := driver.Open(dialector.DSN, nil)
if err != nil {
return err
}
@@ -136,19 +134,51 @@ func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement,
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
if strings.Contains(str, ".") {
for idx, str := range strings.Split(str, ".") {
if idx > 0 {
writer.WriteString(".`")
var (
underQuoted, selfQuoted bool
continuousBacktick int8
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '`':
continuousBacktick++
if continuousBacktick == 2 {
writer.WriteString("``")
continuousBacktick = 0
}
writer.WriteString(str)
writer.WriteByte('`')
case '.':
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteString("`")
}
writer.WriteByte(v)
continue
default:
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteString("`")
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString("``")
}
writer.WriteByte(v)
}
} else {
writer.WriteString(str)
writer.WriteByte('`')
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``")
}
writer.WriteString("`")
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {

View File

@@ -1,11 +1,14 @@
package gormlite
import (
"context"
"fmt"
"testing"
"gorm.io/gorm"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -13,6 +16,17 @@ func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
// Custom connection with a custom function called "my_custom_function".
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText("my-result")
})
})
if err != nil {
t.Fatal(err)
}
rows := []struct {
description string
dialector *Dialector
@@ -29,6 +43,33 @@ func TestDialector(t *testing.T) {
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom function",
dialector: &Dialector{
DSN: InMemoryDSN,
},
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: false,
},
{
description: "Custom connection",
dialector: &Dialector{
Conn: conn,
},
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom connection, custom function",
dialector: &Dialector{
Conn: conn,
},
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {

View File

@@ -3,7 +3,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
rm -rf gorm/ tests/ $TMPDIR/gorm.db
rm -rf gorm/ tests/
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
mv gorm/tests tests
rm -rf gorm/
@@ -11,8 +11,14 @@ rm -rf gorm/
patch -p1 -N < tests.patch
cd tests
go mod edit \
-require github.com/ncruces/go-sqlite3/gormlite@v0.0.0 \
-replace github.com/ncruces/go-sqlite3/gormlite=../ \
-replace github.com/ncruces/go-sqlite3=../../ \
-droprequire gorm.io/driver/sqlite \
-dropreplace gorm.io/gorm
go mod tidy && go work use . && go test
cd ..
rm -rf tests/ $TMPDIR/gorm.db
rm -rf tests/
go work use -r .

View File

@@ -4,26 +4,6 @@ diff --git a/tests/.gitignore b/tests/.gitignore
@@ -1 +1 @@
-go.sum
+*
diff --git a/tests/go.mod b/tests/go.mod
--- a/tests/go.mod
+++ b/tests/go.mod
@@ -7,12 +7,12 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.8
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
+ github.com/ncruces/go-sqlite3 v0.7.2
+ github.com/ncruces/go-sqlite3/gormlite v0.0.0
gorm.io/driver/mysql v1.5.0
gorm.io/driver/postgres v1.5.0
- gorm.io/driver/sqlite v1.5.0
gorm.io/driver/sqlserver v1.5.1
- gorm.io/gorm v1.25.1
+ gorm.io/gorm v1.25.2
)
-replace gorm.io/gorm => ../
+replace github.com/ncruces/go-sqlite3/gormlite => ../
diff --git a/tests/tests_test.go b/tests/tests_test.go
--- a/tests/tests_test.go
+++ b/tests/tests_test.go
@@ -40,3 +20,12 @@ diff --git a/tests/tests_test.go b/tests/tests_test.go
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -89,7 +91,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default:
log.Println("testing sqlite3...")
- db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
+ db, err = gorm.Open(sqlite.Open("file:"+filepath.Join(os.TempDir(), "gorm.db")+"?_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"), cfg)
}
if err != nil {

75
internal/util/handle.go Normal file
View File

@@ -0,0 +1,75 @@
package util
import (
"context"
"io"
"github.com/tetratelabs/wazero/experimental"
)
type handleKey struct{}
type handleState struct {
handles []any
empty int
}
func NewContext(ctx context.Context) context.Context {
state := new(handleState)
ctx = experimental.WithCloseNotifier(ctx, state)
ctx = context.WithValue(ctx, handleKey{}, state)
return ctx
}
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
for _, h := range s.handles {
if c, ok := h.(io.Closer); ok {
c.Close()
}
}
s.handles = nil
s.empty = 0
}
func GetHandle(ctx context.Context, id uint32) any {
if id == 0 {
return nil
}
s := ctx.Value(handleKey{}).(*handleState)
return s.handles[^id]
}
func DelHandle(ctx context.Context, id uint32) error {
if id == 0 {
return nil
}
s := ctx.Value(handleKey{}).(*handleState)
a := s.handles[^id]
s.handles[^id] = nil
s.empty++
if c, ok := a.(io.Closer); ok {
return c.Close()
}
return nil
}
func AddHandle(ctx context.Context, a any) (id uint32) {
if a == nil {
panic(NilErr)
}
s := ctx.Value(handleKey{}).(*handleState)
// Find an empty slot.
if s.empty > cap(s.handles)-len(s.handles) {
for id, h := range s.handles {
if h == nil {
s.empty--
s.handles[id] = a
return ^uint32(id)
}
}
}
// Add a new slot.
s.handles = append(s.handles, a)
return -uint32(len(s.handles))
}

View File

@@ -3,7 +3,6 @@ package sqlite3
import (
"context"
"io"
"math"
"os"
"sync"
@@ -25,72 +24,67 @@ var (
Path string // Path to load the binary from.
)
var sqlite3 struct {
var instance struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
func compileSQLite() {
ctx := context.Background()
instance.runtime = wazero.NewRuntime(ctx)
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
env := sqlite3.runtime.NewHostModuleBuilder("env")
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportHostFunctions(env)
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
bin, instance.err = os.ReadFile(Path)
if instance.err != nil {
return
}
}
if bin == nil {
sqlite3.err = util.BinaryErr
instance.err = util.BinaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mod api.Module
vfs io.Closer
api sqliteAPI
arg [8]uint64
type sqlite struct {
ctx context.Context
mod api.Module
api sqliteAPI
stack [8]uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = new(module)
m.mod = mod
m.ctx, m.vfs = vfs.NewContext(context.Background())
type sqliteKey struct{}
func instantiateSQLite() (sqlt *sqlite, err error) {
instance.once.Do(compileSQLite)
if instance.err != nil {
return nil, instance.err
}
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig())
if err != nil {
return nil, err
}
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
f := sqlt.mod.ExportedFunction(name)
if f == nil {
err = util.NoFuncErr + util.ErrorString(name)
return nil
@@ -99,15 +93,15 @@ func newModule(mod api.Module) (m *module, err error) {
}
getVal := func(name string) uint32 {
g := mod.ExportedGlobal(name)
g := sqlt.mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return util.ReadUint32(mod, uint32(g.Get()))
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
}
m.api = sqliteAPI{
sqlt.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
@@ -155,20 +149,43 @@ func newModule(mod api.Module) (m *module, err error) {
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
anyCollation: getFun("sqlite3_anycollseq_init"),
createCollation: getFun("sqlite3_create_collation_go"),
createFunction: getFun("sqlite3_create_function_go"),
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
createWindow: getFun("sqlite3_create_window_function_go"),
aggregateCtx: getFun("sqlite3_aggregate_context"),
userData: getFun("sqlite3_user_data"),
setAuxData: getFun("sqlite3_set_auxdata_go"),
getAuxData: getFun("sqlite3_get_auxdata"),
valueType: getFun("sqlite3_value_type"),
valueInteger: getFun("sqlite3_value_int64"),
valueFloat: getFun("sqlite3_value_double"),
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
resultErrorBig: getFun("sqlite3_result_error_toobig"),
}
if err != nil {
return nil, err
}
return m, nil
return sqlt, nil
}
func (m *module) close() error {
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
func (sqlt *sqlite) close() error {
return sqlt.mod.Close(sqlt.ctx)
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
@@ -179,16 +196,16 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
if r := m.call(m.api.errstr, rc); r != 0 {
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
@@ -200,60 +217,58 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (m *module) call(fn api.Function, params ...uint64) uint64 {
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
copy(sqlt.stack[:], params)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return m.arg[0]
return sqlt.stack[0]
}
func (m *module) free(ptr uint32) {
func (sqlt *sqlite) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
sqlt.call(sqlt.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
func (sqlt *sqlite) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
ptr := uint32(m.call(m.api.malloc, size))
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
util.WriteBytes(m.mod, ptr, b)
ptr := sqlt.new(uint64(len(b)))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
util.WriteString(m.mod, ptr, s)
func (sqlt *sqlite) newString(s string) uint32 {
ptr := sqlt.new(uint64(len(s) + 1))
util.WriteString(sqlt.mod, ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
func (sqlt *sqlite) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
sqlt: sqlt,
size: uint32(size),
base: sqlt.new(size),
}
}
type arena struct {
m *module
sqlt *sqlite
ptrs []uint32
base uint32
next uint32
@@ -261,17 +276,17 @@ type arena struct {
}
func (a *arena) free() {
if a.m == nil {
if a.sqlt == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
a.sqlt.free(a.base)
a.sqlt = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
a.sqlt.free(ptr)
}
a.ptrs = nil
a.next = 0
@@ -283,7 +298,7 @@ func (a *arena) new(size uint64) uint32 {
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
ptr := a.sqlt.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
@@ -293,13 +308,13 @@ func (a *arena) bytes(b []byte) uint32 {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
util.WriteBytes(a.sqlt.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
util.WriteString(a.m.mod, ptr, s)
util.WriteString(a.sqlt.mod, ptr, s)
return ptr
}
@@ -319,10 +334,10 @@ type sqliteAPI struct {
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindNull api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
@@ -350,5 +365,30 @@ type sqliteAPI struct {
changes api.Function
lastRowid api.Function
autocommit api.Function
anyCollation api.Function
createCollation api.Function
createFunction api.Function
createAggregate api.Function
createWindow api.Function
aggregateCtx api.Function
userData api.Function
setAuxData api.Function
getAuxData api.Function
valueType api.Function
valueInteger api.Function
valueFloat api.Function
valueText api.Function
valueBlob api.Function
valueBytes api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
destructor uint32
}

View File

@@ -3,33 +3,33 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3430100.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
cat *.patch | patch
cat *.patch | patch --posix
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/anycollseq.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/ext/misc/anycollseq.c"
cd ~-
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/mptest/multiwrite01.test"
cd ~-
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.43.1/test/speedtest1.c"
cd ~-

View File

@@ -10,27 +10,31 @@ void go_value(sqlite3_context *);
void go_inverse(sqlite3_context *, int, sqlite3_value **);
void go_destroy(void *);
int sqlite3_create_go_collation(sqlite3 *db, const char *zName, void *pApp) {
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
go_destroy);
}
int sqlite3_create_go_function(sqlite3 *db, const char *zName, int nArg,
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
go_func, NULL, NULL, go_destroy);
}
int sqlite3_create_go_window_function(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
int nArg, int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, NULL, NULL,
go_destroy);
}
int sqlite3_create_go_aggregate_function(sqlite3 *db, const char *zName,
int nArg, int flags, void *pApp) {
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, go_value,
go_inverse, go_destroy);
}
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}

View File

@@ -1,26 +0,0 @@
# Allow the VFS to force memory journal mode
# regardless of SQLITE_OMIT_DESERIALIZE.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -60425,11 +60425,7 @@
int rc = SQLITE_OK; /* Return code */
int tempFile = 0; /* True for temp files (incl. in-memory files) */
int memDb = 0; /* True if this is an in-memory file */
-#ifndef SQLITE_OMIT_DESERIALIZE
int memJM = 0; /* Memory journal mode */
-#else
-# define memJM 0
-#endif
int readOnly = 0; /* True if this is a read-only file */
int journalFileSize; /* Bytes to allocate for each journal fd */
char *zPathname = 0; /* Full path to database file */
@@ -60628,9 +60624,7 @@
int fout = 0; /* VFS flags returned by xOpen() */
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
assert( !memDb );
-#ifndef SQLITE_OMIT_DESERIALIZE
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
-#endif
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;
/* If the file was successfully opened for read/write access,

View File

@@ -29,6 +29,7 @@
#define SQLITE_USE_ALLOCA
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
#define SQLITE_ENABLE_ATOMIC_WRITE
@@ -55,5 +56,7 @@
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
#define SQLITE_SOUNDEX
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

View File

@@ -12,67 +12,67 @@ func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
func Test_sqlite_error_OOM(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
defer func() { _ = recover() }()
m.error(uint64(NOMEM), 0)
sqlite.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_closed(t *testing.T) {
func Test_sqlite_call_closed(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
m.close()
sqlite.close()
defer func() { _ = recover() }()
m.call(m.api.free)
sqlite.call(sqlite.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
func Test_sqlite_new(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(math.MaxUint32)
sqlite.new(math.MaxUint32)
t.Error("want panic")
})
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
func Test_sqlite_newArena(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
arena := m.newArena(16)
arena := sqlite.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -80,7 +80,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -89,7 +89,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
@@ -101,121 +101,121 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
t.Errorf("got %q, want %q", got, title)
}
arena.free()
}
func TestConn_newBytes(t *testing.T) {
func Test_sqlite_newBytes(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newBytes(nil)
ptr := sqlite.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = m.newBytes(buf)
ptr = sqlite.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_newString(t *testing.T) {
func Test_sqlite_newString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newString("")
ptr := sqlite.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = m.newString(str)
ptr = sqlite.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_getString(t *testing.T) {
func Test_sqlite_getString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newString("")
ptr := sqlite.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = m.newString(str)
ptr = sqlite.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := util.ReadString(m.mod, ptr, 0); got != "" {
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
util.ReadString(m.mod, ptr, uint32(len(want)/2))
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
util.ReadString(m.mod, 0, math.MaxUint32)
util.ReadString(sqlite.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}
func TestConn_free(t *testing.T) {
func Test_sqlite_free(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
m.free(0)
sqlite.free(0)
ptr := m.new(1)
ptr := sqlite.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
m.free(ptr)
sqlite.free(ptr)
}

35
stmt.go
View File

@@ -61,12 +61,12 @@ func (s *Stmt) ClearBindings() error {
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r == _ROW {
switch r {
case _ROW:
return true
}
if r == _DONE {
case _DONE:
s.err = nil
} else {
default:
s.err = s.c.error(r)
}
return false
@@ -131,10 +131,11 @@ func (s *Stmt) BindName(param int) string {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBool(param int, value bool) error {
var i int64
if value {
return s.BindInt64(param, 1)
i = 1
}
return s.BindInt64(param, 0)
return s.BindInt64(param, i)
}
// BindInt binds an int to the prepared statement.
@@ -374,18 +375,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
return s.columnRawBytes(col, uint32(r))
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -397,17 +387,18 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
}
ptr := uint32(r)
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
r := s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
}

View File

@@ -2,10 +2,9 @@ package tests
import (
"context"
"database/sql"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}

188
tests/func_test.go Normal file
View File

@@ -0,0 +1,188 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestCreateFunction(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg := arg[0]; arg.Int() {
case 0:
ctx.ResultInt(arg.Int())
case 1:
ctx.ResultInt64(arg.Int64())
case 2:
ctx.ResultBool(arg.Bool())
case 3:
ctx.ResultFloat(arg.Float())
case 4:
ctx.ResultText(arg.Text())
case 5:
ctx.ResultBlob(arg.Blob(nil))
case 6:
ctx.ResultZeroBlob(arg.Int64())
case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8:
ctx.ResultNull()
case 9:
ctx.ResultError(sqlite3.FULL)
}
})
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 1", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 1 {
t.Errorf("got %v, want 2", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnInt64(0); got != 3 {
t.Errorf("got %v, want 3", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnText(0); got != "4" {
t.Errorf("got %s, want 4", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
t.Errorf("got %s, want 5", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
t.Errorf("got %v, want 6", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
t.Errorf("got %v, want 7", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
t.Error("want error")
}
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
t.Errorf("got %v, want sqlite3.FULL", err)
}
}
func TestAnyCollationNeeded(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
db.AnyCollationNeeded()
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 2, 1}
names := []string{"go", "whatever", "zig"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
}

125
value.go Normal file
View File

@@ -0,0 +1,125 @@
package sqlite3
import (
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Value is any value that can be stored in a database table.
//
// https://www.sqlite.org/c3ref/value.html
type Value struct {
*sqlite
handle uint32
}
// Type returns the initial [Datatype] of the value.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Type() Datatype {
r := v.call(v.api.valueType, uint64(v.handle))
return Datatype(r)
}
// Bool returns the value as a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as integers,
// with 0 converted to false and any other value to true.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Bool() bool {
if i := v.Int64(); i != 0 {
return true
}
return false
}
// Int returns the value as an int.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Int() int {
return int(v.Int64())
}
// Int64 returns the value as an int64.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Int64() int64 {
r := v.call(v.api.valueInteger, uint64(v.handle))
return int64(r)
}
// Float returns the value as a float64.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Float() float64 {
r := v.call(v.api.valueFloat, uint64(v.handle))
return math.Float64frombits(r)
}
// Time returns the value as a [time.Time].
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Time(format TimeFormat) time.Time {
var a any
switch v.Type() {
case INTEGER:
a = v.Int64()
case FLOAT:
a = v.Float()
case TEXT, BLOB:
a = v.Text()
case NULL:
return time.Time{}
default:
panic(util.AssertErr())
}
t, _ := format.Decode(a)
return t
}
// Text returns the value as a string.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Text() string {
return string(v.RawText())
}
// Blob appends to buf and returns
// the value as a []byte.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Blob(buf []byte) []byte {
return append(buf, v.RawBlob()...)
}
// RawText returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
r := v.call(v.api.valueText, uint64(v.handle))
return v.rawBytes(uint32(r))
}
// RawBlob returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
r := v.call(v.api.valueBlob, uint64(v.handle))
return v.rawBytes(uint32(r))
}
func (v Value) rawBytes(ptr uint32) []byte {
if ptr == 0 {
return nil
}
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}

View File

@@ -2,8 +2,30 @@
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
It replaces the default VFS with a pure Go implementation,
that is tested on Linux, macOS and Windows,
but which should also work on illumos and the various BSDs.
It replaces the default SQLite VFS with a pure Go implementation.
It also exposes interfaces that should allow you to implement your own custom VFSes.
It also exposes interfaces that should allow you to implement your own custom VFSes.
## Portability
This package is tested on Linux, macOS and Windows,
but it should also work on FreeBSD and illumos
(code paths for those plaforms are tested on macOS and Linux, respectively).
In all platforms for which this package builds,
it should be safe to use it to access databases concurrently,
from multiple goroutines, processes, and
with _other_ implementations of SQLite.
If the package does not build for your platform,
you may try to use the `sqlite3_flock` and `sqlite3_nolock` build tags.
These are only minimally tested and concurrency test failures should be expected.
The `sqlite3_flock` tag uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
It should be safe to access databases concurrently from multiple goroutines and processes,
but **not** with _other_ implementations of SQLite
(_unless_ these are _also_ configured to use `flock`).
The `sqlite3_nolock` tag uses no locking at all.
Database corruption is the likely result from concurrent write access.

View File

@@ -15,7 +15,7 @@ type VFS interface {
FullPathname(name string) (string, error)
}
// VFSParams extends VFS to with the ability to handle URI parameters
// VFSParams extends VFS with the ability to handle URI parameters
// through the OpenParams method.
//
// https://www.sqlite.org/c3ref/uri_boolean.html
@@ -47,7 +47,7 @@ type File interface {
// FileLockState extends File to implement the
// SQLITE_FCNTL_LOCKSTATE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntllockstate
type FileLockState interface {
File
LockState() LockLevel
@@ -56,7 +56,7 @@ type FileLockState interface {
// FileSizeHint extends File to implement the
// SQLITE_FCNTL_SIZE_HINT file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlsizehint
type FileSizeHint interface {
File
SizeHint(size int64) error
@@ -65,16 +65,25 @@ type FileSizeHint interface {
// FileHasMoved extends File to implement the
// SQLITE_FCNTL_HAS_MOVED file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlhasmoved
type FileHasMoved interface {
File
HasMoved() (bool, error)
}
// FileOverwrite extends File to implement the
// SQLITE_FCNTL_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntloverwrite
type FileOverwrite interface {
File
Overwrite() error
}
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlpowersafeoverwrite
type FilePowersafeOverwrite interface {
File
PowersafeOverwrite() bool
@@ -84,7 +93,7 @@ type FilePowersafeOverwrite interface {
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlcommitphasetwo
type FileCommitPhaseTwo interface {
File
CommitPhaseTwo() error
@@ -94,7 +103,7 @@ type FileCommitPhaseTwo interface {
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html#sqlitefcntlbeginatomicwrite
type FileBatchAtomicWrite interface {
File
BeginAtomicWrite() error

9
vfs/clear.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !go1.21
package vfs
func clear(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -9,7 +9,6 @@ import (
"path/filepath"
"runtime"
"syscall"
"time"
)
type vfsOS struct{}
@@ -124,11 +123,10 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock LockLevel
psow bool
syncDir bool
readOnly bool
lock LockLevel
psow bool
syncDir bool
readOnly bool
}
var (

View File

@@ -1,8 +1,9 @@
//go:build !sqlite3_nolock
package vfs
import (
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
@@ -48,7 +49,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
if f.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetSharedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_SHARED
@@ -59,7 +60,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
if f.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetReservedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_RESERVED
@@ -77,7 +78,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
}
f.lock = LOCK_PENDING
}
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
if rc := osGetExclusiveLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_EXCLUSIVE
@@ -134,9 +135,9 @@ func (f *vfsFile) CheckReservedLock() (bool, error) {
return osCheckReservedLock(f.File)
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetPendingLock(file *os.File) _ErrorCode {

View File

@@ -4,7 +4,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -12,13 +11,6 @@ import (
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
@@ -41,8 +33,7 @@ func Test_vfsLock(t *testing.T) {
pOutput = 32
)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
@@ -212,9 +203,4 @@ func Test_vfsLock(t *testing.T) {
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
t.Error("invalid lock state", got)
}
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}

10
vfs/memdb/clear.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build !go1.21
package memdb
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = zero
}
}

View File

@@ -133,7 +133,7 @@ func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
n = copy((*m.data[base])[rest:], b)
if n < len(b) {
// Assume writes are page aligned.
return 0, io.ErrShortWrite
return n, io.ErrShortWrite
}
if size := off + int64(len(b)); size > m.size {
m.size = size
@@ -176,6 +176,8 @@ func (m *memFile) Size() (int64, error) {
return m.size, nil
}
const spinWait = 25 * time.Microsecond
func (m *memFile) Lock(lock vfs.LockLevel) error {
if m.lock >= lock {
return nil
@@ -210,8 +212,8 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
m.pending = m
}
for start := time.Now(); m.shared > 1; {
if time.Since(start) > time.Millisecond {
for before := time.Now(); m.shared > 1; {
if time.Since(before) > spinWait {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
@@ -285,10 +287,3 @@ func divRoundUp(a, b int64) int64 {
func modRoundUp(a, b int64) int64 {
return b - (b-a%b)%b
}
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = zero
}
}

22
vfs/nolock.go Normal file
View File

@@ -0,0 +1,22 @@
//go:build sqlite3_nolock
package vfs
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func (f *vfsFile) Lock(lock LockLevel) error {
return nil
}
func (f *vfsFile) Unlock(lock LockLevel) error {
return nil
}
func (f *vfsFile) CheckReservedLock() (bool, error) {
return false, nil
}

View File

@@ -1,4 +1,4 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
//go:build sqlite3_flock || freebsd
package vfs
@@ -20,16 +20,16 @@ func osUnlock(file *os.File, start, len int64) _ErrorCode {
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
before := time.Now()
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)

View File

@@ -1,4 +1,4 @@
//go:build !sqlite3_bsd
//go:build !sqlite3_flock
package vfs

28
vfs/os_nolock.go Normal file
View File

@@ -0,0 +1,28 @@
//go:build sqlite3_nolock && unix && !(linux || darwin || freebsd || illumos)
package vfs
import (
"os"
"time"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
return _OK
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return _OK
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return _OK
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
return false, _OK
}

View File

@@ -1,4 +1,4 @@
//go:build linux || illumos
//go:build (linux || illumos) && !sqlite3_flock
package vfs
@@ -27,16 +27,16 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
Start: start,
Len: len,
}
before := time.Now()
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)

36
vfs/os_std_access.go Normal file
View File

@@ -0,0 +1,36 @@
//go:build !unix
package vfs
import (
"io/fs"
"os"
)
const (
_S_IREAD = 0400
_S_IWRITE = 0200
_S_IEXEC = 0100
)
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = _S_IREAD
if flags == ACCESS_READWRITE {
want |= _S_IWRITE
}
if fi.IsDir() {
want |= _S_IEXEC
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && (!darwin || sqlite3_bsd)
//go:build !linux && (!darwin || sqlite3_flock)
package vfs
@@ -7,10 +7,6 @@ import (
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {

14
vfs/os_std_mode.go Normal file
View File

@@ -0,0 +1,14 @@
//go:build !unix
package vfs
import "os"
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}

12
vfs/os_std_open.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build !windows
package vfs
import (
"io/fs"
"os"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}

9
vfs/os_std_sync.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !linux && (!darwin || sqlite3_flock)
package vfs
import "os"
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}

View File

@@ -3,7 +3,6 @@
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
@@ -11,10 +10,6 @@ import (
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
@@ -38,22 +33,18 @@ func osSetMode(file *os.File, modeof string) error {
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
func osGetSharedLock(file *os.File) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {

View File

@@ -25,40 +25,9 @@ func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
func osGetSharedLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
rc := osReadLock(file, _PENDING_BYTE, 1, 0)
if rc == _OK {
// Acquire the SHARED lock.
@@ -70,16 +39,12 @@ func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
func osGetExclusiveLock(file *os.File) _ErrorCode {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
if rc != _OK {
// Reacquire the SHARED lock.
@@ -138,6 +103,7 @@ func osUnlock(file *os.File, start, len uint32) _ErrorCode {
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
before := time.Now()
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
@@ -145,11 +111,16 @@ func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
if timeout <= 0 || timeout < time.Since(before) {
break
}
if err := windows.TimeBeginPeriod(1); err != nil {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
if err := windows.TimeEndPeriod(1); err != nil {
break
}
}
return osLockErrorCode(err, def)
}

View File

@@ -1,10 +1,10 @@
package readervfs_test
import (
"bytes"
"database/sql"
"fmt"
"log"
"strings"
_ "embed"
@@ -15,7 +15,7 @@ import (
)
//go:embed testdata/test.db
var testDB []byte
var testDB string
func Example_http() {
readervfs.Create("demo.db", httpreadat.New("https://www.sanford.io/demo.db"))
@@ -65,7 +65,7 @@ func Example_http() {
}
func Example_embed() {
readervfs.Create("test.db", readervfs.NewSizeReaderAt(bytes.NewReader(testDB)))
readervfs.Create("test.db", readervfs.NewSizeReaderAt(strings.NewReader(testDB)))
defer readervfs.Delete("test.db")
db, err := sql.Open("sqlite3", "file:test.db?vfs=reader")

View File

@@ -16,6 +16,7 @@ import (
"sync/atomic"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/tetratelabs/wazero"
@@ -82,16 +83,15 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := vfs.NewContext(ctx)
ctx := util.NewContext(ctx)
mod, _ := rt.InstantiateModule(ctx, module, cfg)
mod.Close(ctx)
vfs.Close()
}()
return 0
}
func Test_config01(t *testing.T) {
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -99,7 +99,6 @@ func Test_config01(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_config02(t *testing.T) {
@@ -110,7 +109,7 @@ func Test_config02(t *testing.T) {
t.Skip("skipping in CI")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -118,7 +117,6 @@ func Test_config02(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_crash01(t *testing.T) {
@@ -126,7 +124,7 @@ func Test_crash01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -134,7 +132,6 @@ func Test_crash01(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01(t *testing.T) {
@@ -142,7 +139,7 @@ func Test_multiwrite01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -150,12 +147,11 @@ func Test_multiwrite01(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_config01_memory(t *testing.T) {
ctx, vfs := vfs.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "test.db",
ctx := util.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "/test.db",
"config01.test",
"--vfs", "memdb",
"--timeout", "1000")
@@ -164,7 +160,6 @@ func Test_config01_memory(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01_memory(t *testing.T) {
@@ -172,7 +167,7 @@ func Test_multiwrite01_memory(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "/test.db",
"multiwrite01.test",
"--vfs", "memdb",
@@ -182,7 +177,6 @@ func Test_multiwrite01_memory(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func newContext(t *testing.T) context.Context {

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d23b37d507077cdcbb616852185370a227278b599187dc134200ed274a7a3a02
size 1441194
oid sha256:5b77e9e13a487e976a6e71bc698542098433d1cc586ad8f24784f1f325ffb8dd
size 1459145

View File

@@ -18,6 +18,7 @@ import (
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
@@ -74,7 +75,7 @@ func initFlags() {
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := vfs.NewContext(context.Background())
ctx := util.NewContext(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().
@@ -88,5 +89,4 @@ func Benchmark_speedtest1(b *testing.B) {
b.Error(err)
}
mod.Close(ctx)
vfs.Close()
}

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
BINARYEN="$ROOT/tools/binaryen-version_116/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83d67feda51cc974634e245ac2b072f9587c607c7ad97321f2de9dde2188e63a
size 1481348
oid sha256:3b52de3306965ac3f812592be29697d75232802a13bb16a34344f8d81dbf0637
size 1499410

View File

@@ -44,33 +44,6 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
return env
}
type vfsKey struct{}
type vfsState struct {
files []File
}
// NewContext is an internal API users need not call directly.
//
// NewContext creates a new context to hold [api.Module] specific VFS data.
// The context should be passed to any [api.Function] calls that might
// generate VFS host callbacks.
// The returned [io.Closer] should be closed after the [api.Module] is closed,
// to release any associated resources.
func NewContext(ctx context.Context) (context.Context, io.Closer) {
vfs := new(vfsState)
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 {
name := util.ReadString(mod, zVfsName, _MAX_STRING)
if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) {
@@ -183,6 +156,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
file, flags, err = vfs.Open(path, flags)
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if !parsed {
params = vfsURIParameters(ctx, mod, zPath, flags)
@@ -192,14 +169,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
}
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
vfsFileRegister(ctx, mod, pFile, file)
if pOutFlags != 0 {
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
vfsFileRegister(ctx, mod, pFile, file)
return _OK
}
@@ -291,14 +264,6 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
return _OK
}
case _FCNTL_LOCK_TIMEOUT:
if file, ok := file.(*vfsFile); ok {
millis := file.lockTimeout.Milliseconds()
file.lockTimeout = time.Duration(util.ReadUint32(mod, pArg)) * time.Millisecond
util.WriteUint32(mod, pArg, uint32(millis))
return _OK
}
case _FCNTL_POWERSAFE_OVERWRITE:
if file, ok := file.(FilePowersafeOverwrite); ok {
switch util.ReadUint32(mod, pArg) {
@@ -336,6 +301,12 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
return vfsErrorCode(err, _IOERR_FSTAT)
}
case _FCNTL_OVERWRITE:
if file, ok := file.(FileOverwrite); ok {
err := file.Overwrite()
return vfsErrorCode(err, _IOERR)
}
case _FCNTL_COMMIT_PHASETWO:
if file, ok := file.(FileCommitPhaseTwo); ok {
err := file.CommitPhaseTwo()
@@ -360,10 +331,8 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
}
// Consider also implementing these opcodes (in use by SQLite):
// _FCNTL_PDB
// _FCNTL_BUSYHANDLER
// _FCNTL_CHUNK_SIZE
// _FCNTL_OVERWRITE
// _FCNTL_PRAGMA
// _FCNTL_SYNC
return _NOTFOUND
@@ -431,40 +400,22 @@ func vfsGet(mod api.Module, pVfs uint32) VFS {
panic(util.NoVFSErr + util.ErrorString(name))
}
func vfsFileNew(vfs *vfsState, file File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) {
const fileHandleOffset = 4
id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file)
id := util.AddHandle(ctx, file)
util.WriteUint32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
return vfs.files[id]
return util.GetHandle(ctx, id).(File)
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
return util.DelHandle(ctx, id)
}
func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
@@ -477,9 +428,3 @@ func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
}
return def
}
func clear(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -220,8 +220,7 @@ func Test_vfsAccess(t *testing.T) {
func Test_vfsFile(t *testing.T) {
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
@@ -293,8 +292,7 @@ func Test_vfsFile(t *testing.T) {
func Test_vfsFile_psow(t *testing.T) {
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)