Compare commits

...

37 Commits

Author SHA1 Message Date
Nuno Cruces
47fe032078 Updated dependencies. 2023-07-26 12:42:18 +01:00
Nuno Cruces
bdfe279444 Soundex. 2023-07-26 02:02:39 +01:00
dependabot[bot]
a86937a54e Bump github.com/tetratelabs/wazero from 1.3.0 to 1.3.1
Bumps [github.com/tetratelabs/wazero](https://github.com/tetratelabs/wazero) from 1.3.0 to 1.3.1.
- [Release notes](https://github.com/tetratelabs/wazero/releases)
- [Commits](https://github.com/tetratelabs/wazero/compare/v1.3.0...v1.3.1)

---
updated-dependencies:
- dependency-name: github.com/tetratelabs/wazero
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-25 08:02:20 +01:00
Nuno Cruces
6ef422fbde Unicode tests. 2023-07-13 12:19:32 +01:00
Nuno Cruces
ff0cb6fb88 Unicode tests, fixes. 2023-07-12 13:39:07 +01:00
Nuno Cruces
72db90efdf Unicode. 2023-07-11 16:34:15 +01:00
Nuno Cruces
5a3fdef3c5 wazero v1.3.0. 2023-07-11 12:30:39 +01:00
dependabot[bot]
ff34b0cae1 Bump golang.org/x/text from 0.10.0 to 0.11.0
Bumps [golang.org/x/text](https://github.com/golang/text) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-04 23:55:17 +01:00
Nuno Cruces
f064492bb1 Updated dependencies. 2023-07-04 19:55:11 +01:00
Nuno Cruces
1427d30541 Updated dependencies. 2023-07-04 19:48:55 +01:00
Nuno Cruces
d3730341f0 Unknown collations. 2023-07-04 11:16:29 +01:00
Nuno Cruces
78ac2386f6 Refactor. 2023-07-04 02:29:38 +01:00
Nuno Cruces
632ea933b3 Function aux data. 2023-07-04 02:18:03 +01:00
Nuno Cruces
0f7fa6ebc9 Tests. 2023-07-03 18:28:46 +01:00
Nuno Cruces
6f7f776488 Refactor. 2023-07-03 17:42:53 +01:00
Nuno Cruces
f6d7c5e9c5 Refactor. 2023-07-03 17:08:16 +01:00
Nuno Cruces
1cc7ecfe8d Custom aggregate functions. 2023-07-03 15:45:16 +01:00
Nuno Cruces
3844e81404 Custom aggregate functions. 2023-07-01 15:19:45 +01:00
Nuno Cruces
fec1f8d32a Custom scalar functions. 2023-07-01 00:16:42 +01:00
Nuno Cruces
31572e6095 Fix nil/zero handles. 2023-06-30 17:09:01 +01:00
Nuno Cruces
4aee38b957 Error handling. 2023-06-30 12:25:07 +01:00
Nuno Cruces
232a7705b5 Wrap context. 2023-06-30 11:48:54 +01:00
Nuno Cruces
a6c2fccd74 Wrap value. 2023-06-30 10:45:16 +01:00
Nuno Cruces
6a982559cd Custom collating sequences. 2023-06-30 02:49:21 +01:00
Nuno Cruces
c7904d30de Refactor file handles. 2023-06-30 01:52:18 +01:00
Nuno Cruces
ce4386604d GORM v1.25.1. 2023-06-29 20:06:56 +01:00
Nuno Cruces
26b62c520d Towards SQL functions. 2023-06-29 14:21:59 +01:00
Nuno Cruces
738714bf32 Fix WAL. 2023-06-26 13:31:42 +01:00
Nuno Cruces
41b020bafc go-sqlite3 v0.8.0. 2023-06-16 17:21:50 +01:00
Nuno Cruces
d0e720272b Optimization flags. 2023-06-15 15:57:39 +01:00
Nuno Cruces
76171da12b go-sqlite3 v0.7.3. 2023-06-15 03:56:02 +01:00
Nuno Cruces
dcc845d684 wazero v1.2.1. 2023-06-15 03:43:25 +01:00
dependabot[bot]
f1b42c26d5 Bump golang.org/x/sync from 0.2.0 to 0.3.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.2.0 to 0.3.0.
- [Commits](https://github.com/golang/sync/compare/v0.2.0...v0.3.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-15 00:13:43 +01:00
dependabot[bot]
1e94407ae7 Bump golang.org/x/sys from 0.8.0 to 0.9.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.8.0 to 0.9.0.
- [Commits](https://github.com/golang/sys/compare/v0.8.0...v0.9.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-13 00:31:02 +01:00
Nuno Cruces
eb8d9b95fd Consistent lock timeouts. 2023-06-12 13:04:37 +01:00
Nuno Cruces
04037a75ed GORM driver sync. 2023-06-12 10:56:03 +01:00
Nuno Cruces
2472ceb0a0 Fix GORM module name. 2023-06-07 12:40:18 +01:00
69 changed files with 1968 additions and 416 deletions

View File

@@ -34,8 +34,9 @@ jobs:
- name: Download
run: go mod download
- name: Verify
run: go mod verify
# Fixed in go 1.21: https://go.dev/issue/54372
# - name: Verify
# run: go mod verify
- name: Vet
run: go vet ./...

View File

@@ -31,16 +31,11 @@ This has benefits, but also comes with some drawbacks.
Because WASM does not support shared memory,
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
To work around this limitation, SQLite is compiled with
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
making `EXCLUSIVE` the default locking mode.
For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
to always use `EXCLUSIVE` locking mode for WAL databases.
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
to open WAL databases you should disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### POSIX Advisory Locks
@@ -68,6 +63,7 @@ Performance is tested by running
### Roadmap
- [ ] advanced SQLite features
- [x] custom functions
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
@@ -77,7 +73,6 @@ Performance is tested by running
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom SQL functions
### Alternatives

View File

@@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r, dst)
return nil, c.sqlite.error(r, dst)
}
return &Backup{

24
conn.go
View File

@@ -19,7 +19,7 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
*module
*sqlite
interrupt context.Context
waiter chan struct{}
@@ -39,7 +39,7 @@ func Open(filename string) (*Conn, error) {
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
@@ -50,19 +50,19 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
mod.close()
sqlite.close()
} else {
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
c := &Conn{module: mod}
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
@@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r, handle, pragmas.String()); err != nil {
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
@@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
panic(err)
}
}
@@ -143,7 +143,7 @@ func (c *Conn) Close() error {
c.handle = 0
runtime.SetFinalizer(c, nil)
return c.module.close()
return c.close()
}
// Exec is a convenience function that allows an application to run
@@ -278,7 +278,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
break
case <-ctx.Done(): // Done was closed.
const isInterruptedOffset = 280
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
@@ -295,7 +295,7 @@ func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
const isInterruptedOffset = 280
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
@@ -319,7 +319,7 @@ func (c *Conn) Pragma(str string) ([]string, error) {
}
func (c *Conn) error(rc uint64, sql ...string) error {
return c.module.error(rc, c.handle, sql...)
return c.sqlite.error(rc, c.handle, sql...)
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.

View File

@@ -167,6 +167,18 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
)
// Datatype is a fundamental datatype of SQLite.
//
// https://www.sqlite.org/c3ref/c_blob.html
@@ -182,18 +194,18 @@ const (
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
const name = "INTEGERFLOATEXTBLOBNULL"
switch t {
case INTEGER:
return name[0:7]
case FLOAT:
return name[7:12]
case TEXT:
return name[12:16]
return name[11:15]
case BLOB:
return name[16:20]
return name[15:19]
case NULL:
return name[20:24]
return name[19:23]
}
return strconv.FormatUint(uint64(t), 10)
}

174
context.go Normal file
View File

@@ -0,0 +1,174 @@
package sqlite3
import (
"errors"
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Context is the context in which an SQL function executes.
// An SQLite [Context] is in no way related to a Go [context.Context].
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
*sqlite
handle uint32
}
// SetAuxData saves metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(c.ctx, data)
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://www.sqlite.org/c3ref/get_auxdata.html
func (c Context) GetAuxData(n int) any {
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
return util.GetHandle(c.ctx, ptr)
}
// ResultBool sets the result of the function to a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBool(value bool) {
var i int64
if value {
i = 1
}
c.ResultInt64(i)
}
// ResultInt sets the result of the function to an int.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt(value int) {
c.ResultInt64(int64(value))
}
// ResultInt64 sets the result of the function to an int64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultInt64(value int64) {
c.call(c.api.resultInteger,
uint64(c.handle), uint64(value))
}
// ResultFloat sets the result of the function to a float64.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultFloat(value float64) {
c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value))
}
// ResultText sets the result of the function to a string.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultText(value string) {
ptr := c.newString(value)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultBlob(value []byte) {
ptr := c.newBytes(value)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.api.destructor))
}
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultZeroBlob(n int64) {
c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n))
}
// ResultNull sets the result of the function to NULL.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultNull() {
c.call(c.api.resultNull,
uint64(c.handle))
}
// ResultTime sets the result of the function to a [time.Time].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultTime(value time.Time, format TimeFormat) {
if format == TimeFormatDefault {
c.resultRFC3339Nano(value)
return
}
switch v := format.Encode(value).(type) {
case string:
c.ResultText(v)
case int64:
c.ResultInt64(v)
case float64:
c.ResultFloat(v)
default:
panic(util.AssertErr())
}
}
func (c Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := c.new(maxlen)
buf := util.View(c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.api.destructor), _UTF8)
}
// ResultError sets the result of the function an error.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
c.call(c.api.resultErrorMem, uint64(c.handle))
return
}
if errors.Is(err, TOOBIG) {
c.call(c.api.resultErrorBig, uint64(c.handle))
return
}
str := err.Error()
ptr := c.newString(str)
c.call(c.api.resultError,
uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
var code uint64
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
code = uint64(xcode)
case errors.As(err, &ecode):
code = uint64(ecode)
}
if code != 0 {
c.call(c.api.resultErrorCode,
uint64(c.handle), code)
}
}

View File

@@ -14,10 +14,9 @@
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
//
// If no PRAGMAs are specified, a busy timeout of 1 minute
// and normal locking mode are used.
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
//
// Order matters:
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
@@ -53,9 +52,14 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
if err != nil {
return nil, err
}
defer func() {
if err != nil {
c.Close()
}
}()
var pragmas bool
c.txBegin = "BEGIN"
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
@@ -66,20 +70,15 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
}
pragmas = query["_pragma"]
pragmas = len(query["_pragma"]) > 0
}
}
if len(pragmas) == 0 {
err := c.Conn.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
if !pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
if err != nil {
c.Close()
return nil, err
}
c.reusable = true
@@ -90,7 +89,6 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
PRAGMA_query_only;
`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
@@ -99,7 +97,6 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}

View File

@@ -9,6 +9,7 @@ The following optional features are compiled in:
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [soundex](https://www.sqlite.org/lang_corefunc.html#soundex)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)

View File

@@ -4,24 +4,27 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals \
-msimd128 -mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--initial-memory=327680 \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
$(awk '{print "-Wl,--export="$0}' exports.txt)
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
--enable-multivalue --enable-mutable-globals \
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
sqlite3.tmp -o sqlite3.wasm \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

View File

@@ -33,10 +33,10 @@ sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_reopen
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
@@ -46,4 +46,29 @@ sqlite3_uri_parameter
sqlite3_uri_key
sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit
sqlite3_get_autocommit
sqlite3_anycollseq_init
sqlite3_create_collation_go
sqlite3_create_function_go
sqlite3_create_aggregate_function_go
sqlite3_create_window_function_go
sqlite3_aggregate_context
sqlite3_user_data
sqlite3_set_auxdata_go
sqlite3_get_auxdata
sqlite3_value_type
sqlite3_value_int64
sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem
sqlite3_result_error_toobig

Binary file not shown.

View File

@@ -68,6 +68,19 @@ func (e *Error) Is(err error) bool {
return false
}
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
func (e *Error) As(err any) bool {
switch c := err.(type) {
case *ErrorCode:
*c = e.Code()
return true
case *ExtendedErrorCode:
*c = e.ExtendedCode()
return true
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
@@ -104,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool {
return ok && c == ErrorCode(e)
}
// As converts this error to an [ErrorCode].
func (e ExtendedErrorCode) As(err any) bool {
c, ok := err.(*ErrorCode)
if ok {
*c = ErrorCode(e)
}
return ok
}
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY

View File

@@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) {
func TestError(t *testing.T) {
t.Parallel()
err := Error{code: 0x8080}
if rc := err.Code(); rc != 0x80 {
t.Errorf("got %#x, want 0x80", rc)
var ecode ErrorCode
var xcode xErrorCode
err := &Error{code: 0x8080}
if !errors.As(err, &err) {
t.Fatal("want true")
}
if !errors.Is(&err, ErrorCode(0x80)) {
if ecode := err.Code(); ecode != 0x80 {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if !errors.Is(err, ErrorCode(0x80)) {
t.Errorf("want true")
}
if rc := err.ExtendedCode(); rc != 0x8080 {
t.Errorf("got %#x, want 0x8080", rc)
if xcode := err.ExtendedCode(); xcode != 0x8080 {
t.Errorf("got %#x, want 0x8080", uint16(xcode))
}
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
t.Errorf("got %#x, want 0x8080", uint16(xcode))
}
if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
t.Errorf("got %q", s)
}
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
t.Errorf("got %#x, want 0x80", uint8(ecode))
}
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
t.Errorf("want true")
}

176
ext/unicode/unicode.go Normal file
View File

@@ -0,0 +1,176 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Provides Unicode aware:
// - upper and lower functions,
// - LIKE and REGEXP operators,
// - collation sequences.
//
// This package is not 100% compatible with the ICU extension:
// - upper and lower use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regex/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
package unicode
import (
"bytes"
"regexp"
"strings"
"unicode/utf8"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"golang.org/x/text/cases"
"golang.org/x/text/collate"
"golang.org/x/text/language"
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("like", 2, flags, like)
db.CreateFunction("like", 3, flags, like)
db.CreateFunction("upper", 1, flags, upper)
db.CreateFunction("upper", 2, flags, upper)
db.CreateFunction("lower", 1, flags, lower)
db.CreateFunction("lower", 2, flags, lower)
db.CreateFunction("regexp", 2, flags, regex)
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
tag, err := language.Parse(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
err = db.CreateCollation(name, collate.New(tag).Compare)
if err != nil {
ctx.ResultError(err)
return
}
})
}
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return
}
c := cases.Upper(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
}
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return
}
c := cases.Lower(t)
ctx.SetAuxData(1, c)
cs = c
}
ctx.ResultBlob(cs.Bytes(arg[0].RawBlob()))
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
re = r
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
}
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
escape := rune(-1)
if len(arg) == 3 {
var size int
b := arg[2].RawBlob()
escape, size = utf8.DecodeRune(b)
if size != len(b) {
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
return
}
}
type likeData struct {
*regexp.Regexp
escape rune
}
re, ok := ctx.GetAuxData(0).(likeData)
if !ok || re.escape != escape {
re = likeData{
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
escape,
}
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
}
func like2regex(pattern string, escape rune) string {
var re strings.Builder
start := 0
literal := false
re.Grow(len(pattern) + 10)
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
for i, r := range pattern {
if start < 0 {
start = i
}
if literal {
literal = false
continue
}
var symbol string
switch r {
case '_':
symbol = `.`
case '%':
symbol = `.*`
case escape:
literal = true
default:
continue
}
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
re.WriteString(symbol)
start = -1
}
if start >= 0 {
re.WriteString(regexp.QuoteMeta(pattern[start:]))
}
re.WriteString(`\z`)
return re.String()
}

215
ext/unicode/unicode_test.go Normal file
View File

@@ -0,0 +1,215 @@
package unicode
import (
"errors"
"reflect"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestRegister(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
exec := func(fn string) string {
stmt, _, err := db.Prepare(`SELECT ` + fn)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnText(0)
}
t.Fatal(stmt.Err())
return ""
}
Register(db)
tests := []struct {
test string
want string
}{
{`upper('hello')`, "HELLO"},
{`lower('HELLO')`, "hello"},
{`upper('привет')`, "ПРИВЕТ"},
{`lower('ПРИВЕТ')`, "привет"},
{`upper('istanbul')`, "ISTANBUL"},
{`upper('istanbul', 'tr-TR')`, "İSTANBUL"},
{`lower('Dünyanın İlk Borsası', 'tr-TR')`, "dünyanın ilk borsası"},
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
{`'Hello' REGEXP 'ell'`, "1"},
{`'Hello' REGEXP 'el.'`, "1"},
{`'Hello' LIKE 'hel_'`, "0"},
{`'Hello' LIKE 'hel%'`, "1"},
{`'Hello' LIKE 'h_llo'`, "1"},
{`'Hello' LIKE 'hello'`, "1"},
{`'Привет' LIKE 'ПРИВЕТ'`, "1"},
{`'100%' LIKE '100|%' ESCAPE '|'`, "1"},
}
for _, tt := range tests {
t.Run(tt.test, func(t *testing.T) {
if got := exec(tt.test); got != tt.want {
t.Errorf("exec(%q) = %q, want %q", tt.test, got, tt.want)
}
})
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRegister_collation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"}
for stmt.Step() {
got = append(got, stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Error("not equal")
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRegister_error(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
Register(db)
err = db.Exec(`SELECT upper('hello', 'enUS')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT lower('hello', 'enUS')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT 'hello' LIKE 'HELLO' ESCAPE '\\'`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT icu_load_collation('enUS', 'error')`)
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.ERROR) {
t.Errorf("got %v, want sqlite3.ERROR", err)
}
err = db.Exec(`SELECT icu_load_collation('enUS', '')`)
if err != nil {
t.Error(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}
func Test_like2regex(t *testing.T) {
const prefix = `(?is)\A`
const sufix = `\z`
tests := []struct {
pattern string
escape rune
want string
}{
{`a`, -1, `a`},
{`a.`, -1, `a\.`},
{`a%`, -1, `a.*`},
{`a\`, -1, `a\\`},
{`a_b`, -1, `a.b`},
{`a|b`, '|', `ab`},
{`a|_`, '|', `a_`},
}
for _, tt := range tests {
t.Run(tt.pattern, func(t *testing.T) {
want := prefix + tt.want + sufix
if got := like2regex(tt.pattern, tt.escape); got != want {
t.Errorf("like2regex() = %q, want %q", got, want)
}
})
}
}

186
func.go Normal file
View File

@@ -0,0 +1,186 @@
package sqlite3
import (
"context"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// AnyCollationNeeded registers a fake collating function
// for any unknown collating sequence.
// The fake collating function works like BINARY.
//
// This extension can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
}
// CreateCollation defines a new collating sequence.
//
// https://www.sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createCollation,
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
if err := c.error(r); err != nil {
util.DelHandle(c.ctx, funcPtr)
return err
}
return nil
}
// CreateFunction defines a new scalar SQL function.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
call := c.api.createAggregate
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
if _, ok := fn().(WindowFunction); ok {
call = c.api.createWindow
}
r := c.call(call,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
// AggregateFunction is the interface an aggregate function should implement.
//
// https://www.sqlite.org/appfunc.html
type AggregateFunction interface {
// Step is invoked to add a row to the current window.
// The function arguments, if any, corresponding to the row being added are passed to Step.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current value of the aggregate.
Value(ctx Context)
}
// WindowFunction is the interface an aggregate window function should implement.
//
// https://www.sqlite.org/windowfunctions.html
type WindowFunction interface {
AggregateFunction
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
// The function arguments, if any, are those passed to Step for the row being removed.
Inverse(ctx Context, arg ...Value)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{sqlite, pCtx}.ResultError(err)
}
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
return util.GetHandle(sqlite.ctx, pApp)
}
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
// Try loading the handle, if we already have one, or want a new one.
if ptr != 0 || size != 0 {
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
fn := util.GetHandle(sqlite.ctx, handle)
if close != nil {
*close = handle
}
if fn != nil {
return fn
}
}
}
// Create a new aggregate and store the handle.
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
}
return fn
}
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
sqlite: sqlite,
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
}
}
return args
}

154
func_test.go Normal file
View File

@@ -0,0 +1,154 @@
package sqlite3_test
import (
"bytes"
"fmt"
"log"
"regexp"
"golang.org/x/text/collate"
"golang.org/x/text/language"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateCollation() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateCollation("french", collate.New(language.French).Compare)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// cote
// coté
// côte
// côté
// cotée
// coter
}
func ExampleConn_CreateFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// COTE
// COTÉ
// CÔTE
// CÔTÉ
// COTÉE
// COTER
}
func ExampleContext_SetAuxData() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return
}
ctx.SetAuxData(0, r)
re = r
}
ctx.ResultBool(re.Match(arg[1].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Unordered output:
// cote
// côte
// cotée
}

87
func_win_test.go Normal file
View File

@@ -0,0 +1,87 @@
package sqlite3_test
import (
"fmt"
"log"
"unicode"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func ExampleConn_CreateWindowFunction() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 1
// 2
// 2
// 1
// 0
// 0
}
type countASCII struct{ result int }
func newASCIICounter() sqlite3.AggregateFunction {
return &countASCII{}
}
func (f *countASCII) Value(ctx sqlite3.Context) {
ctx.ResultInt(f.result)
}
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if f.isASCII(arg[0]) {
f.result++
}
}
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if f.isASCII(arg[0]) {
f.result--
}
}
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
if arg.Type() != sqlite3.TEXT {
return false
}
for _, c := range arg.RawBlob() {
if c > unicode.MaxASCII {
return false
}
}
return true
}

7
go.mod
View File

@@ -5,9 +5,10 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.2.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.0
github.com/tetratelabs/wazero v1.3.1
golang.org/x/sync v0.3.0
golang.org/x/sys v0.10.0
golang.org/x/text v0.11.0
)
retract v0.4.0 // tagged from the wrong branch

14
go.sum
View File

@@ -2,9 +2,11 @@ github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FB
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
github.com/tetratelabs/wazero v1.3.1 h1:rnb9FgOEQRLLR8tgoD1mfjNjMhFeWRUk+a4b4j/GpUM=
github.com/tetratelabs/wazero v1.3.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=

6
go.work Normal file
View File

@@ -0,0 +1,6 @@
go 1.19
use (
.
./gormlite
)

4
go.work.sum Normal file
View File

@@ -0,0 +1,4 @@
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=

View File

@@ -1,5 +1,7 @@
# GORM SQLite Driver
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
## Usage
```go
@@ -19,6 +21,6 @@ Checkout [https://gorm.io](https://gorm.io) for details.
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
```go
db, err := gorm.Open(gormlite.Open(
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=foreign_keys(1)"),
&gorm.Config{})
```

View File

@@ -162,7 +162,7 @@ func parseDDL(strs ...string) (*ddl, error) {
for _, column := range getAllColumns(matches[1]) {
for idx, c := range result.columns {
if c.NameValue.String == column {
c.UniqueValue = sql.NullBool{Bool: true, Valid: true}
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
result.columns[idx] = c
}
}

View File

@@ -79,6 +79,24 @@ func TestParseDDL(t *testing.T) {
},
},
},
{
"non-unique index",
[]string{
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}
for _, p := range params {

11
gormlite/download.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"

View File

@@ -8,8 +8,14 @@ import (
)
func (dialector Dialector) Translate(err error) error {
if errors.Is(err, sqlite3.CONSTRAINT_UNIQUE) {
switch {
case
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
return gorm.ErrDuplicatedKey
case
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
return gorm.ErrForeignKeyViolated
}
return err
}

View File

@@ -1,7 +0,0 @@
package gormlite
import "errors"
var (
ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
)

View File

@@ -1,19 +1,16 @@
module github.com/ncruces/go-sqlite/gormlite
module github.com/ncruces/go-sqlite3/gormlite
go 1.19
require (
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/ncruces/go-sqlite3 v0.7.1
gorm.io/driver/mysql v1.5.1
gorm.io/gorm v1.25.1
github.com/ncruces/go-sqlite3 v0.8.3
gorm.io/gorm v1.25.2
)
require (
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect
github.com/tetratelabs/wazero v1.3.1 // indirect
golang.org/x/sys v0.10.0 // indirect
)

View File

@@ -1,20 +1,15 @@
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.7.1 h1:SDd3g18RobYi3NM8nZgozHM6jaqIbpMEmX42YGOVcTU=
github.com/ncruces/go-sqlite3 v0.7.1/go.mod h1:n+DEDYam8SK5jmsfUC/9GFhSF0gVHGXiYFXnAo8Jwsc=
github.com/ncruces/go-sqlite3 v0.8.3 h1:kYUAqDpZ0OT+snTH1yWyxq9QSJ22HoM3WKfFEL4N694=
github.com/ncruces/go-sqlite3 v0.8.3/go.mod h1:DUdzKfMlIFmSLAtNHdIgxbdax/5NsQx2RlIlVO7EWfU=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.2.0 h1:I/8LMf4YkCZ3r2XaL9whhA0VMyAvF6QE+O7rco0DCeQ=
github.com/tetratelabs/wazero v1.2.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw=
gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o=
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
github.com/tetratelabs/wazero v1.3.1 h1:rnb9FgOEQRLLR8tgoD1mfjNjMhFeWRUk+a4b4j/GpUM=
github.com/tetratelabs/wazero v1.3.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

View File

@@ -322,6 +322,9 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
if err := m.DropIndex(value, oldName); err != nil {
return err
}
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)

View File

@@ -3,12 +3,16 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
rm -rf gorm/ tests/
git clone --filter=blob:none --branch=v1.25.1 https://github.com/go-gorm/gorm.git
rm -rf gorm/ tests/ $TMPDIR/gorm.db
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
mv gorm/tests tests
rm -rf gorm/
patch -p1 -N < tests.patch
cd tests
go mod tidy && go test
go mod tidy && go work use . && go test
cd ..
rm -rf tests/ $TMPDIR/gorm.db
go work use -r .

View File

@@ -1,51 +1,31 @@
diff --git a/tests/.gitignore b/tests/.gitignore
index 08cb523..72e8ffc 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1 @@
-go.sum
+*
diff --git a/tests/go.mod b/tests/go.mod
index f47d175..dba4a24 100644
--- a/tests/go.mod
+++ b/tests/go.mod
@@ -7,13 +7,10 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
@@ -6,9 +6,10 @@ require (
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.8
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
+ github.com/ncruces/go-sqlite3 v0.7.1
golang.org/x/crypto v0.8.0 // indirect
gorm.io/driver/mysql v1.5.0
gorm.io/driver/postgres v1.5.0
- gorm.io/driver/sqlite v1.5.0
gorm.io/driver/sqlserver v1.4.3
- gorm.io/gorm v1.25.0
-)
-
github.com/lib/pq v1.10.9
+ github.com/ncruces/go-sqlite3 v0.8.3
+ github.com/ncruces/go-sqlite3/gormlite v0.0.0
gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0
gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196
- gorm.io/driver/sqlite v1.5.2
gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a
gorm.io/gorm v1.25.2
)
@@ -27,4 +28,4 @@ require (
golang.org/x/text v0.11.0 // indirect
)
-replace gorm.io/gorm => ../
+ gorm.io/gorm v1.25.1
+)
\ No newline at end of file
diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go
index 1412169..472434b 100644
--- a/tests/scanner_valuer_test.go
+++ b/tests/scanner_valuer_test.go
@@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
return errors.New("Too short")
}
- *data = b[3:]
+ *data = append((*data)[0:], b[3:]...)
return nil
} else if s, ok := value.(string); ok {
- *data = []byte(s)[3:]
+ *data = []byte(s[3:])
return nil
}
+replace github.com/ncruces/go-sqlite3/gormlite => ../
diff --git a/tests/tests_test.go b/tests/tests_test.go
index 90eb847..cd9af43 100644
--- a/tests/tests_test.go
+++ b/tests/tests_test.go
@@ -7,9 +7,11 @@ import (
@@ -61,3 +41,12 @@ index 90eb847..cd9af43 100644
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -89,7 +91,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default:
log.Println("testing sqlite3...")
- db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
+ db, err = gorm.Open(sqlite.Open("file:"+filepath.Join(os.TempDir(), "gorm.db")+"?_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"), cfg)
}
if err != nil {

View File

@@ -10,6 +10,32 @@ import (
type i32 interface{ ~int32 | ~uint32 }
type i64 interface{ ~int64 | ~uint64 }
type funcVI[T0 i32] func(context.Context, api.Module, T0)
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]))
}
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVI[T0](fn),
[]api.ValueType{api.ValueTypeI32}, nil).
Export(name)
}
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
}
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {

75
internal/util/handle.go Normal file
View File

@@ -0,0 +1,75 @@
package util
import (
"context"
"io"
"github.com/tetratelabs/wazero/experimental"
)
type handleKey struct{}
type handleState struct {
handles []any
empty int
}
func NewContext(ctx context.Context) context.Context {
state := new(handleState)
ctx = experimental.WithCloseNotifier(ctx, state)
ctx = context.WithValue(ctx, handleKey{}, state)
return ctx
}
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
for _, h := range s.handles {
if c, ok := h.(io.Closer); ok {
c.Close()
}
}
s.handles = nil
s.empty = 0
}
func GetHandle(ctx context.Context, id uint32) any {
if id == 0 {
return nil
}
s := ctx.Value(handleKey{}).(*handleState)
return s.handles[^id]
}
func DelHandle(ctx context.Context, id uint32) error {
if id == 0 {
return nil
}
s := ctx.Value(handleKey{}).(*handleState)
a := s.handles[^id]
s.handles[^id] = nil
s.empty++
if c, ok := a.(io.Closer); ok {
return c.Close()
}
return nil
}
func AddHandle(ctx context.Context, a any) (id uint32) {
if a == nil {
panic(NilErr)
}
s := ctx.Value(handleKey{}).(*handleState)
// Find an empty slot.
if s.empty > cap(s.handles)-len(s.handles) {
for id, h := range s.handles {
if h == nil {
s.empty--
s.handles[id] = a
return ^uint32(id)
}
}
}
// Add a new slot.
s.handles = append(s.handles, a)
return -uint32(len(s.handles))
}

View File

@@ -3,7 +3,6 @@ package sqlite3
import (
"context"
"io"
"math"
"os"
"sync"
@@ -25,70 +24,67 @@ var (
Path string // Path to load the binary from.
)
var sqlite3 struct {
var instance struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
func compileSQLite() {
ctx := context.Background()
instance.runtime = wazero.NewRuntime(ctx)
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
env := vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportHostFunctions(env)
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
bin, instance.err = os.ReadFile(Path)
if instance.err != nil {
return
}
}
if bin == nil {
sqlite3.err = util.BinaryErr
instance.err = util.BinaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mod api.Module
vfs io.Closer
api sqliteAPI
arg [8]uint64
type sqlite struct {
ctx context.Context
mod api.Module
api sqliteAPI
stack [8]uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = new(module)
m.mod = mod
m.ctx, m.vfs = vfs.NewContext(context.Background())
type sqliteKey struct{}
func instantiateSQLite() (sqlt *sqlite, err error) {
instance.once.Do(compileSQLite)
if instance.err != nil {
return nil, instance.err
}
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig())
if err != nil {
return nil, err
}
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
f := sqlt.mod.ExportedFunction(name)
if f == nil {
err = util.NoFuncErr + util.ErrorString(name)
return nil
@@ -97,15 +93,15 @@ func newModule(mod api.Module) (m *module, err error) {
}
getVal := func(name string) uint32 {
g := mod.ExportedGlobal(name)
g := sqlt.mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return util.ReadUint32(mod, uint32(g.Get()))
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
}
m.api = sqliteAPI{
sqlt.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
@@ -153,20 +149,43 @@ func newModule(mod api.Module) (m *module, err error) {
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
anyCollation: getFun("sqlite3_anycollseq_init"),
createCollation: getFun("sqlite3_create_collation_go"),
createFunction: getFun("sqlite3_create_function_go"),
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
createWindow: getFun("sqlite3_create_window_function_go"),
aggregateCtx: getFun("sqlite3_aggregate_context"),
userData: getFun("sqlite3_user_data"),
setAuxData: getFun("sqlite3_set_auxdata_go"),
getAuxData: getFun("sqlite3_get_auxdata"),
valueType: getFun("sqlite3_value_type"),
valueInteger: getFun("sqlite3_value_int64"),
valueFloat: getFun("sqlite3_value_double"),
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
resultErrorBig: getFun("sqlite3_result_error_toobig"),
}
if err != nil {
return nil, err
}
return m, nil
return sqlt, nil
}
func (m *module) close() error {
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
func (sqlt *sqlite) close() error {
return sqlt.mod.Close(sqlt.ctx)
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
@@ -177,16 +196,16 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
if r := m.call(m.api.errstr, rc); r != 0 {
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
@@ -198,60 +217,58 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (m *module) call(fn api.Function, params ...uint64) uint64 {
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
copy(sqlt.stack[:], params)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return m.arg[0]
return sqlt.stack[0]
}
func (m *module) free(ptr uint32) {
func (sqlt *sqlite) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
sqlt.call(sqlt.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
func (sqlt *sqlite) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
ptr := uint32(m.call(m.api.malloc, size))
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
util.WriteBytes(m.mod, ptr, b)
ptr := sqlt.new(uint64(len(b)))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
util.WriteString(m.mod, ptr, s)
func (sqlt *sqlite) newString(s string) uint32 {
ptr := sqlt.new(uint64(len(s) + 1))
util.WriteString(sqlt.mod, ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
func (sqlt *sqlite) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
sqlt: sqlt,
size: uint32(size),
base: sqlt.new(size),
}
}
type arena struct {
m *module
sqlt *sqlite
ptrs []uint32
base uint32
next uint32
@@ -259,17 +276,17 @@ type arena struct {
}
func (a *arena) free() {
if a.m == nil {
if a.sqlt == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
a.sqlt.free(a.base)
a.sqlt = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
a.sqlt.free(ptr)
}
a.ptrs = nil
a.next = 0
@@ -281,7 +298,7 @@ func (a *arena) new(size uint64) uint32 {
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
ptr := a.sqlt.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
@@ -291,13 +308,13 @@ func (a *arena) bytes(b []byte) uint32 {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
util.WriteBytes(a.sqlt.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
util.WriteString(a.m.mod, ptr, s)
util.WriteString(a.sqlt.mod, ptr, s)
return ptr
}
@@ -317,10 +334,10 @@ type sqliteAPI struct {
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindNull api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
@@ -348,5 +365,30 @@ type sqliteAPI struct {
changes api.Function
lastRowid api.Function
autocommit api.Function
anyCollation api.Function
createCollation api.Function
createFunction api.Function
createAggregate api.Function
createWindow api.Function
aggregateCtx api.Function
userData api.Function
setAuxData api.Function
getAuxData api.Function
valueType api.Function
valueInteger api.Function
valueFloat api.Function
valueText api.Function
valueBlob api.Function
valueBytes api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function
resultErrorBig api.Function
destructor uint32
}

1
sqlite3/.gitignore vendored
View File

@@ -1,3 +1,4 @@
ext/
sqlite3.c
sqlite3.h
sqlite3ext.h

View File

@@ -8,9 +8,9 @@ unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
patch < vfs_find.patch
patch < deserialize.patch
cat *.patch | patch
mkdir -p ext/
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
@@ -18,6 +18,7 @@ curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/anycollseq.c"
cd ~-
cd ../vfs/tests/mptest/testdata/

View File

@@ -1 +0,0 @@
*.c

40
sqlite3/func.c Normal file
View File

@@ -0,0 +1,40 @@
#include <string.h>
#include "sqlite3.h"
int go_compare(void *, int, const void *, int, const void *);
void go_func(sqlite3_context *, int, sqlite3_value **);
void go_step(sqlite3_context *, int, sqlite3_value **);
void go_final(sqlite3_context *);
void go_value(sqlite3_context *);
void go_inverse(sqlite3_context *, int, sqlite3_value **);
void go_destroy(void *);
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
go_destroy);
}
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
go_func, NULL, NULL, go_destroy);
}
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
int nArg, int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, NULL, NULL,
go_destroy);
}
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
int flags, void *pApp) {
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
pApp, go_step, go_final, go_value,
go_inverse, go_destroy);
}
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}

View File

@@ -0,0 +1,14 @@
# Use exclusive locking mode for WAL databases with v1 VFSes.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -63210,7 +63210,9 @@
SQLITE_PRIVATE int sqlite3PagerWalSupported(Pager *pPager){
const sqlite3_io_methods *pMethods = pPager->fd->pMethods;
if( pPager->noLock ) return 0;
- return pPager->exclusiveMode || (pMethods->iVersion>=2 && pMethods->xShmMap);
+ if( pMethods->iVersion>=2 && pMethods->xShmMap ) return 1;
+ pPager->exclusiveMode = 1;
+ return 1;
}
/*

View File

@@ -1,19 +1,19 @@
#include <stdbool.h>
#include <stddef.h>
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS
#include "vfs.c"
// Extensions
#include "ext/anycollseq.c"
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
#include "func.c"
#include "time.c"
__attribute__((constructor)) void init() {

View File

@@ -1,20 +1,26 @@
# Allow the VFS to force memory journal mode
# regardless of SQLITE_OMIT_DESERIALIZE.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -60425,7 +60425,7 @@
@@ -60425,11 +60425,7 @@
int rc = SQLITE_OK; /* Return code */
int tempFile = 0; /* True for temp files (incl. in-memory files) */
int memDb = 0; /* True if this is an in-memory file */
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
int memJM = 0; /* Memory journal mode */
#else
# define memJM 0
@@ -60628,7 +60628,7 @@
-#else
-# define memJM 0
-#endif
int readOnly = 0; /* True if this is a read-only file */
int journalFileSize; /* Bytes to allocate for each journal fd */
char *zPathname = 0; /* Full path to database file */
@@ -60628,9 +60624,7 @@
int fout = 0; /* VFS flags returned by xOpen() */
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
assert( !memDb );
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
#endif
-#endif
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;
/* If the file was successfully opened for read/write access,

View File

@@ -29,6 +29,7 @@
#define SQLITE_USE_ALLOCA
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
#define SQLITE_ENABLE_ATOMIC_WRITE
@@ -36,12 +37,9 @@
// Because WASM does not support shared memory,
// SQLite disables WAL for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// We patch SQLite to use exclusive locking mode instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#ifndef SQLITE_DEFAULT_LOCKING_MODE
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Amalgamated Extensions
@@ -58,5 +56,7 @@
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
#define SQLITE_SOUNDEX
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

View File

@@ -134,4 +134,4 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 280, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");

View File

@@ -1,3 +1,4 @@
# Wrap sqlite3_vfs_find.
--- sqlite3.c.orig
+++ sqlite3.c
@@ -25394,7 +25394,7 @@

View File

@@ -12,67 +12,67 @@ func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
func Test_sqlite_error_OOM(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
defer func() { _ = recover() }()
m.error(uint64(NOMEM), 0)
sqlite.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_closed(t *testing.T) {
func Test_sqlite_call_closed(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
m.close()
sqlite.close()
defer func() { _ = recover() }()
m.call(m.api.free)
sqlite.call(sqlite.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
func Test_sqlite_new(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(math.MaxUint32)
sqlite.new(math.MaxUint32)
t.Error("want panic")
})
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
sqlite.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
func Test_sqlite_newArena(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
arena := m.newArena(16)
arena := sqlite.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -80,7 +80,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -89,7 +89,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
@@ -101,121 +101,121 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
t.Errorf("got %q, want %q", got, title)
}
arena.free()
}
func TestConn_newBytes(t *testing.T) {
func Test_sqlite_newBytes(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newBytes(nil)
ptr := sqlite.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = m.newBytes(buf)
ptr = sqlite.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_newString(t *testing.T) {
func Test_sqlite_newString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newString("")
ptr := sqlite.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = m.newString(str)
ptr = sqlite.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestConn_getString(t *testing.T) {
func Test_sqlite_getString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
ptr := m.newString("")
ptr := sqlite.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = m.newString(str)
ptr = sqlite.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := util.ReadString(m.mod, ptr, 0); got != "" {
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
util.ReadString(m.mod, ptr, uint32(len(want)/2))
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
util.ReadString(m.mod, 0, math.MaxUint32)
util.ReadString(sqlite.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}
func TestConn_free(t *testing.T) {
func Test_sqlite_free(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
sqlite, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
defer m.close()
defer sqlite.close()
m.free(0)
sqlite.free(0)
ptr := m.new(1)
ptr := sqlite.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
m.free(ptr)
sqlite.free(ptr)
}

27
stmt.go
View File

@@ -131,10 +131,11 @@ func (s *Stmt) BindName(param int) string {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBool(param int, value bool) error {
var i int64
if value {
return s.BindInt64(param, 1)
i = 1
}
return s.BindInt64(param, 0)
return s.BindInt64(param, i)
}
// BindInt binds an int to the prepared statement.
@@ -374,18 +375,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
return s.columnRawBytes(col, uint32(r))
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -397,17 +387,18 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
return s.columnRawBytes(col, uint32(r))
}
ptr := uint32(r)
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
r := s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r)
}

View File

@@ -43,7 +43,7 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
db, err := sql.Open("sqlite3", "file:"+
filepath.Join(t.TempDir(), "foo.db")+
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
"?_pragma=busy_timeout(10000)&_pragma=synchronous(off)")
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}

View File

@@ -1,14 +1,20 @@
package tests
import (
"os"
"path/filepath"
"testing"
_ "embed"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/wal.db
var waldb []byte
func TestDB_memory(t *testing.T) {
t.Parallel()
testDB(t, ":memory:")
@@ -19,6 +25,16 @@ func TestDB_file(t *testing.T) {
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func TestDB_wal(t *testing.T) {
t.Parallel()
wal := filepath.Join(t.TempDir(), "test.db")
err := os.WriteFile(wal, waldb, 0666)
if err != nil {
t.Fatal(err)
}
testDB(t, wal)
}
func TestDB_vfs(t *testing.T) {
testDB(t, "file:test.db?vfs=memdb")
}

188
tests/func_test.go Normal file
View File

@@ -0,0 +1,188 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestCreateFunction(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg := arg[0]; arg.Int() {
case 0:
ctx.ResultInt(arg.Int())
case 1:
ctx.ResultInt64(arg.Int64())
case 2:
ctx.ResultBool(arg.Bool())
case 3:
ctx.ResultFloat(arg.Float())
case 4:
ctx.ResultText(arg.Text())
case 5:
ctx.ResultBlob(arg.Blob(nil))
case 6:
ctx.ResultZeroBlob(arg.Int64())
case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8:
ctx.ResultNull()
case 9:
ctx.ResultError(sqlite3.FULL)
}
})
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 1", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 1 {
t.Errorf("got %v, want 2", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnInt64(0); got != 3 {
t.Errorf("got %v, want 3", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnText(0); got != "4" {
t.Errorf("got %s, want 4", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
t.Errorf("got %s, want 5", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
t.Errorf("got %v, want 6", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
t.Errorf("got %v, want 7", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
t.Error("want error")
}
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
t.Errorf("got %v, want sqlite3.FULL", err)
}
}
func TestAnyCollationNeeded(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
db.AnyCollationNeeded()
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 2, 1}
names := []string{"go", "whatever", "zig"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
}

View File

@@ -25,7 +25,6 @@ func TestParallel(t *testing.T) {
name := "file:" +
filepath.Join(t.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
@@ -42,7 +41,6 @@ func TestMemory(t *testing.T) {
name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
@@ -59,7 +57,6 @@ func TestMultiProcess(t *testing.T) {
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
@@ -93,7 +90,6 @@ func TestChildProcess(t *testing.T) {
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
@@ -128,10 +124,7 @@ func testParallel(t *testing.T, name string, n int) {
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA locking_mode=normal;
`)
err = db.Exec(`PRAGMA busy_timeout=10000`)
if err != nil {
return err
}

BIN
tests/testdata/wal.db vendored Normal file

Binary file not shown.

125
value.go Normal file
View File

@@ -0,0 +1,125 @@
package sqlite3
import (
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Value is any value that can be stored in a database table.
//
// https://www.sqlite.org/c3ref/value.html
type Value struct {
*sqlite
handle uint32
}
// Type returns the initial [Datatype] of the value.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Type() Datatype {
r := v.call(v.api.valueType, uint64(v.handle))
return Datatype(r)
}
// Bool returns the value as a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as integers,
// with 0 converted to false and any other value to true.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Bool() bool {
if i := v.Int64(); i != 0 {
return true
}
return false
}
// Int returns the value as an int.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Int() int {
return int(v.Int64())
}
// Int64 returns the value as an int64.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Int64() int64 {
r := v.call(v.api.valueInteger, uint64(v.handle))
return int64(r)
}
// Float returns the value as a float64.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Float() float64 {
r := v.call(v.api.valueFloat, uint64(v.handle))
return math.Float64frombits(r)
}
// Time returns the value as a [time.Time].
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Time(format TimeFormat) time.Time {
var a any
switch v.Type() {
case INTEGER:
a = v.Int64()
case FLOAT:
a = v.Float()
case TEXT, BLOB:
a = v.Text()
case NULL:
return time.Time{}
default:
panic(util.AssertErr())
}
t, _ := format.Decode(a)
return t
}
// Text returns the value as a string.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Text() string {
return string(v.RawText())
}
// Blob appends to buf and returns
// the value as a []byte.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) Blob(buf []byte) []byte {
return append(buf, v.RawBlob()...)
}
// RawText returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
r := v.call(v.api.valueText, uint64(v.handle))
return v.rawBytes(uint32(r))
}
// RawBlob returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
r := v.call(v.api.valueBlob, uint64(v.handle))
return v.rawBytes(uint32(r))
}
func (v Value) rawBytes(ptr uint32) []byte {
if ptr == 0 {
return nil
}
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}

View File

@@ -15,7 +15,7 @@ type VFS interface {
FullPathname(name string) (string, error)
}
// VFSParams extends VFS to with the ability to handle URI parameters
// VFSParams extends VFS with the ability to handle URI parameters
// through the OpenParams method.
//
// https://www.sqlite.org/c3ref/uri_boolean.html

View File

@@ -41,8 +41,7 @@ func Test_vfsLock(t *testing.T) {
pOutput = 32
)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})

View File

@@ -187,17 +187,11 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
deadline := time.Now().Add(time.Millisecond)
switch lock {
case vfs.LOCK_SHARED:
for m.pending != nil {
if time.Now().After(deadline) {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
runtime.Gosched()
m.lockMtx.Lock()
if m.pending != nil {
return sqlite3.BUSY
}
m.shared++
@@ -216,8 +210,8 @@ func (m *memFile) Lock(lock vfs.LockLevel) error {
m.pending = m
}
for m.shared > 1 {
if time.Now().After(deadline) {
for start := time.Now(); m.shared > 1; {
if time.Since(start) > time.Millisecond {
return sqlite3.BUSY
}
m.lockMtx.Unlock()

View File

@@ -28,7 +28,7 @@ var (
)
// Create creates an immutable database from reader.
// The caller should insure that data from reader does not mutate,
// The caller should ensure that data from reader does not mutate,
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
func Create(name string, reader SizeReaderAt) {
readerMtx.Lock()

View File

@@ -37,7 +37,7 @@ func (readerVFS) FullPathname(name string) (string, error) {
type readerFile struct{ SizeReaderAt }
func (r readerFile) Close() error {
func (readerFile) Close() error {
return nil
}

View File

@@ -16,6 +16,7 @@ import (
"sync/atomic"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/tetratelabs/wazero"
@@ -36,8 +37,7 @@ var (
)
func TestMain(m *testing.M) {
ctx := context.TODO()
ctx := context.Background()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
@@ -83,16 +83,15 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := vfs.NewContext(ctx)
ctx := util.NewContext(ctx)
mod, _ := rt.InstantiateModule(ctx, module, cfg)
mod.Close(ctx)
vfs.Close()
}()
return 0
}
func Test_config01(t *testing.T) {
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -100,7 +99,6 @@ func Test_config01(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_config02(t *testing.T) {
@@ -111,7 +109,7 @@ func Test_config02(t *testing.T) {
t.Skip("skipping in CI")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -119,7 +117,6 @@ func Test_config02(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_crash01(t *testing.T) {
@@ -127,7 +124,7 @@ func Test_crash01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -135,7 +132,6 @@ func Test_crash01(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01(t *testing.T) {
@@ -143,7 +139,7 @@ func Test_multiwrite01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -151,11 +147,10 @@ func Test_multiwrite01(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_config01_memory(t *testing.T) {
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "test.db",
"config01.test",
"--vfs", "memdb",
@@ -165,7 +160,6 @@ func Test_config01_memory(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01_memory(t *testing.T) {
@@ -173,7 +167,7 @@ func Test_multiwrite01_memory(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
ctx := util.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "/test.db",
"multiwrite01.test",
"--vfs", "memdb",
@@ -183,7 +177,6 @@ func Test_multiwrite01_memory(t *testing.T) {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func newContext(t *testing.T) context.Context {

View File

@@ -4,25 +4,28 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o mptest.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-msimd128 -mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid
"$BINARYEN/wasm-opt" -g -O2 mptest.wasm -o mptest.tmp \
--enable-multivalue --enable-mutable-globals \
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
mptest.wasm -o mptest.tmp \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv mptest.tmp mptest.wasm

View File

@@ -1,8 +1,6 @@
#include <stdbool.h>
#include <stddef.h>
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53bb204be0ddd312c23976e7b627ad62fb384d220557f4e3f723d89273d78a01
size 1477009
oid sha256:d23b37d507077cdcbb616852185370a227278b599187dc134200ed274a7a3a02
size 1441194

View File

@@ -18,6 +18,7 @@ import (
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
@@ -74,7 +75,7 @@ func initFlags() {
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := vfs.NewContext(context.Background())
ctx := util.NewContext(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().
@@ -88,5 +89,4 @@ func Benchmark_speedtest1(b *testing.B) {
b.Error(err)
}
mod.Close(ctx)
vfs.Close()
}

View File

@@ -4,20 +4,23 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o speedtest1.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-msimd128 -mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--stack-first \
-Wl,--import-undefined
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H
"$BINARYEN/wasm-opt" -g -O2 speedtest1.wasm -o speedtest1.tmp \
--enable-multivalue --enable-mutable-globals \
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
speedtest1.wasm -o speedtest1.tmp \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv speedtest1.tmp speedtest1.wasm

View File

@@ -1,8 +1,6 @@
#include <stdbool.h>
#include <stddef.h>
// Configuration
#include "sqlite_cfg.h"
// Amalgamation
#include "sqlite3.c"
// VFS

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:23ac5877bc34cf5ff06ed2cf1c5ac148e5d364d6d01c85bd0f99d14f0a681830
size 1513516
oid sha256:83d67feda51cc974634e245ac2b072f9587c607c7ad97321f2de9dde2188e63a
size 1481348

View File

@@ -44,33 +44,6 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
return env
}
type vfsKey struct{}
type vfsState struct {
files []File
}
// NewContext is an internal API users need not call directly.
//
// NewContext creates a new context to hold [api.Module] specific VFS data.
// The context should be passed to any [api.Function] calls that might
// generate VFS host callbacks.
// The returned [io.Closer] should be closed after the [api.Module] is closed,
// to release any associated resources.
func NewContext(ctx context.Context) (context.Context, io.Closer) {
vfs := new(vfsState)
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 {
name := util.ReadString(mod, zVfsName, _MAX_STRING)
if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) {
@@ -183,6 +156,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
file, flags, err = vfs.Open(path, flags)
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if !parsed {
params = vfsURIParameters(ctx, mod, zPath, flags)
@@ -192,14 +169,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
}
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
vfsFileRegister(ctx, mod, pFile, file)
if pOutFlags != 0 {
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
vfsFileRegister(ctx, mod, pFile, file)
return _OK
}
@@ -431,40 +404,22 @@ func vfsGet(mod api.Module, pVfs uint32) VFS {
panic(util.NoVFSErr + util.ErrorString(name))
}
func vfsFileNew(vfs *vfsState, file File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) {
const fileHandleOffset = 4
id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file)
id := util.AddHandle(ctx, file)
util.WriteUint32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
return vfs.files[id]
return util.GetHandle(ctx, id).(File)
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
return util.DelHandle(ctx, id)
}
func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {

View File

@@ -220,8 +220,7 @@ func Test_vfsAccess(t *testing.T) {
func Test_vfsFile(t *testing.T) {
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
@@ -293,8 +292,7 @@ func Test_vfsFile(t *testing.T) {
func Test_vfsFile_psow(t *testing.T) {
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
ctx := util.NewContext(context.TODO())
// Open a temporary file.
rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)