Compare commits

...

28 Commits

Author SHA1 Message Date
Nuno Cruces
56e8281bdb Time collation tests. 2023-03-10 16:42:20 +00:00
Nuno Cruces
f61d430e65 Documentation. 2023-03-10 16:26:19 +00:00
Nuno Cruces
dbaed53b9a Sync and delete improvements. 2023-03-10 14:17:02 +00:00
Nuno Cruces
8b1bfd04e3 Simplify windows hacks. 2023-03-10 10:43:02 +00:00
Nuno Cruces
11c1687146 Time collation. 2023-03-09 14:42:29 +00:00
Nuno Cruces
94c43a8685 Use access syscall. 2023-03-09 01:59:46 +00:00
Nuno Cruces
a25159a070 Fix sharing violation. 2023-03-09 01:23:52 +00:00
Nuno Cruces
e007e9b060 Windows fixes. 2023-03-08 20:10:46 +00:00
Nuno Cruces
66a730893f Fix readonly transaction rollback. 2023-03-08 18:07:21 +00:00
Nuno Cruces
926adeb3f5 Remove MustPrepare. 2023-03-08 17:39:41 +00:00
Nuno Cruces
677f51bec1 Savepoint API. 2023-03-08 17:39:23 +00:00
Nuno Cruces
5d6f92b733 Documentation, tests, tweaks. 2023-03-08 13:29:33 +00:00
Nuno Cruces
f5747f19fb Tests. 2023-03-07 14:19:22 +00:00
Nuno Cruces
dfcdbf9c4c Online backup. 2023-03-07 12:15:29 +00:00
Nuno Cruces
ad1e8f4b0e Refactor. 2023-03-07 10:47:55 +00:00
Nuno Cruces
8f29882671 Pass mptest crash. 2023-03-07 04:37:55 +00:00
Nuno Cruces
6c96a019e6 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
d291738b81 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
c1263d4f33 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
1ebdc1aa93 Towards shared modules: refactor. 2023-03-07 04:37:55 +00:00
Nuno Cruces
4dd10f071a Towards shared modules: backup. 2023-03-07 04:37:55 +00:00
Nuno Cruces
7dbddfa5c0 Towards shared modules. 2023-03-07 04:37:55 +00:00
dependabot[bot]
ce5e035801 Bump golang.org/x/sys from 0.5.0 to 0.6.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.5.0 to 0.6.0.
- [Release notes](https://github.com/golang/sys/releases)
- [Commits](https://github.com/golang/sys/compare/v0.5.0...v0.6.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-03-07 04:11:45 +00:00
Nuno Cruces
8bb8367a36 Refactor mptest. 2023-03-06 14:27:49 +00:00
Nuno Cruces
9f59b3d0ec Pass mptest multiwrite on Windows. 2023-03-05 14:33:51 +00:00
Nuno Cruces
5f893b5459 Add SQLite mptest. 2023-03-05 12:20:02 +00:00
Nuno Cruces
35b1a97b88 Use finalizers to detect unclosed connections. 2023-03-03 14:50:55 +00:00
Nuno Cruces
416c3863a0 Documentation, tests, dependencies. 2023-03-01 23:47:24 +00:00
59 changed files with 2635 additions and 1023 deletions

View File

@@ -15,6 +15,8 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
lfs: 'true'
- name: Set up Go
uses: actions/setup-go@v3

5
.gitignore vendored
View File

@@ -13,7 +13,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
tools
# Project
sqlite3/sqlite3*
tools

View File

@@ -2,7 +2,7 @@
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3)
[![Go Report](https://goreportcard.com/badge/github.com/ncruces/go-sqlite3)](https://goreportcard.com/report/github.com/ncruces/go-sqlite3)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://raw.githack.com/wiki/ncruces/go-sqlite3/coverage.html)
[![Go Coverage](https://github.com/ncruces/go-sqlite3/wiki/coverage.svg)](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report)
Go module `github.com/ncruces/go-sqlite3` wraps a [WASM](https://webassembly.org/) build of [SQLite](https://sqlite.org/),
and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings.
@@ -18,6 +18,10 @@ embeds a build of SQLite into your application.
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
with a pure Go implementation.
This has numerous benefits, but also comes with some caveats.
#### Write-Ahead Logging
Because WASM does not support shared memory,
@@ -30,8 +34,10 @@ For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode,
and WAL databases are not supported.
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Open File Description Locks
@@ -43,23 +49,22 @@ OFD locks are fully compatible with process-associated POSIX advisory locks,
and are supported on Linux, macOS and illumos.
As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.org/uri.html).
#### Testing
The pure Go VFS is stress tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
### Roadmap
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] design a nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on macOS/Linux/Windows
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [ ] online backup
- [x] online backup
- [ ] snapshots
- [ ] session extension
- [ ] resumable bulk update
- [ ] shared cache mode
- [ ] shared-cache mode
- [ ] unlock-notify
- [ ] custom SQL functions
- [ ] custom VFSes

126
api.go
View File

@@ -1,126 +0,0 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"github.com/tetratelabs/wazero/api"
)
func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
getFun := func(name string) api.Function {
f := module.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := module.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return memory{module}.readUint32(uint32(global.Get()))
}
c := Conn{
ctx: ctx,
mem: memory{module},
api: sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
interrupt: getVal("sqlite3_interrupt_offset"),
},
}
if err != nil {
return nil, err
}
return &c, nil
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
interrupt uint32
}

134
backup.go Normal file
View File

@@ -0,0 +1,134 @@
package sqlite3
// Backup is a handle to an open BLOB.
//
// https://www.sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle uint32
otherc uint32
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup calls [Open] to open the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
// https://www.sqlite.org/backup.html
func (src *Conn) Backup(srcDB, dstURI string) error {
b, err := src.BackupInit(srcDB, dstURI)
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore calls [Open] to open the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://www.sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
if err != nil {
return err
}
b, err := dst.backupInit(dst.handle, dstDB, src, "main")
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit calls [Open] to open the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupinit
func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
dst, err := src.openDB(dstURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
if err != nil {
return nil, err
}
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) {
defer c.arena.reset()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
other := dst
if c.handle == dst {
other = src
}
r := c.call(c.api.backupInit,
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r[0] == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r[0], dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r[0]),
}, nil
}
// Close finishes a backup operation.
//
// It is safe to close a nil, zero or closed Backup.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupfinish
func (b *Backup) Close() error {
if b == nil || b.handle == 0 {
return nil
}
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r[0])
}
// Step copies up to nPage pages between the source and destination databases.
// If nPage is negative, all remaining source pages are copied.
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
if r[0] == _DONE {
return true, nil
}
return false, b.c.error(r[0])
}
// Remaining returns the number of pages still to be backed up
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
return int(r[0])
}
// PageCount returns the total number of pages in the source database
// at the conclusion of the most recent [Backup.Step].
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
return int(r[0])
}

View File

@@ -25,6 +25,7 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)

View File

@@ -1,68 +0,0 @@
package sqlite3
import (
"context"
"crypto/rand"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
var sqlite3 sqlite3Runtime
type sqlite3Runtime struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
}
func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) {
s.once.Do(func() { s.compileModule(ctx) })
if s.err != nil {
return nil, s.err
}
cfg := wazero.NewModuleConfig().
WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10)).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
return s.runtime.InstantiateModule(ctx, s.compiled, cfg)
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
return
}
}
if bin == nil {
s.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
}

240
conn.go
View File

@@ -3,14 +3,13 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"math"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
@@ -18,11 +17,9 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
ctx context.Context
api sqliteAPI
mem memory
handle uint32
*module
handle uint32
arena arena
interrupt context.Context
waiter chan struct{}
@@ -30,8 +27,8 @@ type Conn struct {
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (conn *Conn, err error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
@@ -41,33 +38,43 @@ func Open(filename string) (conn *Conn, err error) {
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background()
module, err := sqlite3.instantiateModule(ctx)
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(filename, flags)
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
if err != nil {
return nil, err
}
defer func() {
if conn == nil {
module.Close(ctx)
mod.close()
} else {
runtime.SetFinalizer(conn, finalizer[Conn](3))
}
}()
c, err := newConn(ctx, module)
c := &Conn{module: mod}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
}
c.arena = c.newArena(1024)
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
defer c.arena.reset()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
c.handle = c.mem.readUint32(connPtr)
if err := c.error(r[0]); err != nil {
return nil, err
handle := c.mem.readUint32(connPtr)
if err := c.module.error(r[0], handle); err != nil {
c.closeDB(handle)
return 0, err
}
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
@@ -80,11 +87,28 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
pragmas.WriteByte(';')
}
}
if err := c.Exec(pragmas.String()); err != nil {
return nil, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
c.closeDB(handle)
return 0, err
}
}
return c, nil
c.call(c.api.timeCollation, uint64(handle))
return handle, nil
}
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r[0], handle); err != nil {
panic(err)
}
}
// Close closes the database connection.
@@ -102,6 +126,8 @@ func (c *Conn) Close() error {
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
@@ -109,7 +135,8 @@ func (c *Conn) Close() error {
}
c.handle = 0
return c.mem.mod.Close(c.ctx)
runtime.SetFinalizer(c, nil)
return c.module.close()
}
// Exec is a convenience function that allows an application to run
@@ -125,22 +152,6 @@ func (c *Conn) Exec(sql string) error {
return c.error(r[0])
}
// MustPrepare calls [Conn.Prepare] and panics on error,
// a nil Stmt, or a non-empty tail.
func (c *Conn) MustPrepare(sql string) *Stmt {
s, tail, err := c.PrepareFlags(sql, 0)
if err != nil {
panic(err)
}
if s == nil {
panic(emptyErr)
}
if !emptyStatement(tail) {
panic(tailErr)
}
return s
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
@@ -228,26 +239,23 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Reset the pending statement.
if c.pending != nil {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
} else {
c.pending.Reset()
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
@@ -287,144 +295,22 @@ func (c *Conn) checkInterrupt() bool {
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) []string {
stmt := c.MustPrepare(`PRAGMA ` + str)
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil {
return nil, err
}
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas
return pragmas, stmt.Close()
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
}
var r []uint64
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (c *Conn) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(c.ctx, params...)
if err != nil {
panic(err)
}
return r
}
func (c *Conn) free(ptr uint32) {
if ptr == 0 {
return
}
c.call(c.api.free, uint64(ptr))
}
func (c *Conn) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
}
r := c.call(c.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
}
func (c *Conn) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := c.new(uint64(len(b)))
c.mem.writeBytes(ptr, b)
return ptr
}
func (c *Conn) newString(s string) uint32 {
ptr := c.new(uint64(len(s) + 1))
c.mem.writeString(ptr, s)
return ptr
}
func (c *Conn) newArena(size uint64) arena {
return arena{
c: c,
base: c.new(size),
size: uint32(size),
}
}
type arena struct {
c *Conn
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
if a.c == nil {
return
}
a.reset()
a.c.free(a.base)
a.c = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.c.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.c.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.c.mem.writeString(ptr, s)
return ptr
return c.module.error(rc, c.handle, sql...)
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.
@@ -439,6 +325,6 @@ type DriverConn interface {
driver.ExecerContext
driver.ConnPrepareContext
Savepoint() (release func(*error))
Savepoint() Savepoint
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -133,6 +133,7 @@ const (
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
@@ -167,14 +168,6 @@ const (
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_prepare_normalize.html
@@ -216,3 +209,19 @@ func (t Datatype) String() string {
}
return strconv.FormatUint(uint64(t), 10)
}
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
type _SyncFlag uint32
const (
_SYNC_NORMAL _SyncFlag = 0x00002
_SYNC_FULL _SyncFlag = 0x00003
_SYNC_DATAONLY _SyncFlag = 0x00010
)

View File

@@ -86,15 +86,17 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
type conn struct {
conn *sqlite3.Conn
txBegin string
txCommit string
conn *sqlite3.Conn
txBegin string
txCommit string
txRollback string
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.Validator = conn{}
_ sqlite3.DriverConn = conn{}
)
@@ -102,27 +104,49 @@ func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() (valid bool) {
r, err := c.conn.Pragma("locking_mode")
return err == nil && len(r) == 1 && r[0] == "normal"
}
func (c conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + c.conn.Pragma("query_only")[0]
query_only, err := c.conn.Pragma("query_only")
if err != nil {
return nil, err
}
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + query_only[0]
c.txRollback = c.txCommit
}
switch opts.Isolation {
default:
return nil, isolationErr
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
case driver.IsolationLevel(sql.LevelReadUncommitted):
read_uncommitted, err := c.conn.Pragma("read_uncommitted")
if err != nil {
return nil, err
}
txBegin += `; PRAGMA read_uncommitted=on`
c.txCommit += `; PRAGMA read_uncommitted=` + read_uncommitted[0]
c.txRollback += `; PRAGMA read_uncommitted=` + read_uncommitted[0]
}
err := c.conn.Exec(txBegin)
@@ -134,14 +158,14 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil {
if err != nil && !c.conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
return c.conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
@@ -189,7 +213,7 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
}, nil
}
func (c conn) Savepoint() (release func(*error)) {
func (c conn) Savepoint() sqlite3.Savepoint {
return c.conn.Savepoint()
}

View File

@@ -1,4 +1,3 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
@@ -134,7 +133,9 @@ func Test_BeginTx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
@@ -145,6 +146,16 @@ func Test_BeginTx(t *testing.T) {
t.Error("want isolationErr")
}
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted})
if err != nil {
t.Fatal(err)
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
tx1, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)

View File

@@ -28,8 +28,8 @@ func Example() {
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
defer db.Close()
// Create a table with some data in it.
err = albumsSetup()

View File

@@ -20,8 +20,8 @@ func ExampleDriverConn() {
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("demo.db")
defer db.Close()
ctx := context.Background()
@@ -48,7 +48,8 @@ func ExampleDriverConn() {
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
defer conn.Savepoint()(&err)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {

15
embed/README.md Normal file
View File

@@ -0,0 +1,15 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.41.0 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
- math functions
- FTS3/4/5
- JSON
- R*Tree
- GeoPoly
See the [configuration options](../sqlite3/sqlite_cfg.h).
Built using [`zig`](https://ziglang.org/) version 0.10.1.

View File

@@ -13,51 +13,4 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-Wl,--export=free \
-Wl,--export=malloc \
-Wl,--export=malloc_destructor \
-Wl,--export=sqlite3_errcode \
-Wl,--export=sqlite3_errstr \
-Wl,--export=sqlite3_errmsg \
-Wl,--export=sqlite3_error_offset \
-Wl,--export=sqlite3_open_v2 \
-Wl,--export=sqlite3_close \
-Wl,--export=sqlite3_prepare_v3 \
-Wl,--export=sqlite3_finalize \
-Wl,--export=sqlite3_reset \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_parameter_index \
-Wl,--export=sqlite3_bind_parameter_name \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_blob_open \
-Wl,--export=sqlite3_blob_close \
-Wl,--export=sqlite3_blob_bytes \
-Wl,--export=sqlite3_blob_read \
-Wl,--export=sqlite3_blob_write \
-Wl,--export=sqlite3_blob_reopen \
-Wl,--export=sqlite3_get_autocommit \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_unlock_notify \
-Wl,--export=sqlite3_backup_init \
-Wl,--export=sqlite3_backup_step \
-Wl,--export=sqlite3_backup_finish \
-Wl,--export=sqlite3_backup_remaining \
-Wl,--export=sqlite3_backup_pagecount \
-Wl,--export=sqlite3_interrupt_offset \
$(awk '{print "-Wl,--export="$0}' exports.txt)

50
embed/exports.txt Normal file
View File

@@ -0,0 +1,50 @@
free
malloc
malloc_destructor
sqlite3_errcode
sqlite3_errstr
sqlite3_errmsg
sqlite3_error_offset
sqlite3_open_v2
sqlite3_close
sqlite3_close_v2
sqlite3_prepare_v3
sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index
sqlite3_bind_parameter_name
sqlite3_bind_null
sqlite3_bind_int64
sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
sqlite3_column_int64
sqlite3_column_double
sqlite3_column_text
sqlite3_column_blob
sqlite3_column_bytes
sqlite3_blob_open
sqlite3_blob_close
sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_unlock_notify
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_time_collation
sqlite3_interrupt_offset

View File

@@ -4,9 +4,6 @@
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
package embed
import (

Binary file not shown.

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"fmt"
"runtime"
"strconv"
"strings"
@@ -201,8 +202,6 @@ const (
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
emptyErr = errorString("sqlite3: empty statement")
tailErr = errorString("sqlite3: non-empty tail")
notImplErr = errorString("sqlite3: not implemented")
whenceErr = errorString("sqlite3: invalid whence")
offsetErr = errorString("sqlite3: invalid offset")
@@ -215,3 +214,11 @@ func assertErr() errorString {
}
return errorString(msg)
}
func finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(errorString(msg)) }
}

View File

@@ -1,7 +1,6 @@
package sqlite3
import (
"context"
"errors"
"strings"
"testing"
@@ -9,7 +8,7 @@ import (
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") {
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:10)") {
t.Errorf("got %q", s)
}
}
@@ -120,10 +119,8 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
@@ -144,10 +141,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {

View File

@@ -26,7 +26,11 @@ func Example() {
log.Fatal(err)
}
stmt := db.MustPrepare(`SELECT id, name FROM users`)
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))

6
go.mod
View File

@@ -4,7 +4,9 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-pre.9
github.com/tetratelabs/wazero v1.0.0-rc.1
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
golang.org/x/sys v0.6.0
)
retract v0.4.0 // tagged from the wrong branch

8
go.sum
View File

@@ -1,8 +1,8 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-pre.9 h1:2uVdi2bvTi/JQxG2cp3LRm2aRadd3nURn5jcfbvqZcw=
github.com/tetratelabs/wazero v1.0.0-pre.9/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
github.com/tetratelabs/wazero v1.0.0-rc.1 h1:ytecMV5Ue0BwezjKh/cM5yv1Mo49ep2R2snSsQUyToc=
github.com/tetratelabs/wazero v1.0.0-rc.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

354
module.go Normal file
View File

@@ -0,0 +1,354 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"crypto/rand"
"io"
"math"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
// Configure SQLite WASM.
//
// Importing package embed initializes these
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
)
var sqlite3 struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
}
func instantiateModule() (*module, error) {
ctx := context.Background()
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
name := "sqlite3-" + strconv.FormatUint(sqlite3.instances.Add(1), 10)
cfg := wazero.NewModuleConfig().WithName(name).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
}
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, sqlite3.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
return
}
}
if bin == nil {
sqlite3.err = binaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
}
type module struct {
ctx context.Context
mem memory
api sqliteAPI
vfs io.Closer
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mem = memory{mod}
m.ctx, m.vfs = vfsContext(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := mod.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return m.mem.readUint32(uint32(global.Get()))
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
timeCollation: getFun("sqlite3_time_collation"),
interrupt: getVal("sqlite3_interrupt_offset"),
}
if err != nil {
return nil, err
}
return m, nil
}
func (m *module) close() error {
err := m.mem.mod.Close(m.ctx)
m.vfs.Close()
return err
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = m.mem.readString(uint32(r[0]), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r = m.call(m.api.erroff, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(m.ctx, params...)
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return r
}
func (m *module) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
m.mem.writeBytes(ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
m.mem.writeString(ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
size: uint32(size),
}
}
type arena struct {
m *module
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
if a.m == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
}
a.ptrs = nil
a.next = 0
}
func (a *arena) new(size uint64) uint32 {
if size <= uint64(a.size-a.next) {
ptr := a.base + a.next
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.m.mem.writeString(ptr, s)
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
timeCollation api.Function
interrupt uint32
}

View File

@@ -9,43 +9,43 @@ import (
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.error(uint64(NOMEM))
m.error(uint64(NOMEM), 0)
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
defer func() { _ = recover() }()
db.call(db.api.free)
m.call(m.api.free)
t.Error("want panic")
}
func TestConn_new(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
testOOM := func(size uint64) {
defer func() { _ = recover() }()
db.new(size)
m.new(size)
t.Error("want panic")
}
@@ -56,13 +56,13 @@ func TestConn_new(t *testing.T) {
func TestConn_newArena(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
arena := db.newArena(16)
arena := m.newArena(16)
defer arena.free()
const title = "Lorem ipsum"
@@ -71,7 +71,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != title {
if got := m.mem.readString(ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -80,7 +80,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := db.mem.readString(ptr, math.MaxUint32); got != body {
if got := m.mem.readString(ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
@@ -89,25 +89,25 @@ func TestConn_newArena(t *testing.T) {
func TestConn_newBytes(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newBytes(nil)
ptr := m.newBytes(nil)
if ptr != 0 {
t.Errorf("got %#x, want nullptr", ptr)
}
buf := []byte("sqlite3")
ptr = db.newBytes(buf)
ptr = m.newBytes(buf)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := buf
if got := db.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -115,25 +115,25 @@ func TestConn_newBytes(t *testing.T) {
func TestConn_newString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3\000sqlite3"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := str + "\000"
if got := db.mem.view(ptr, uint64(len(want))); string(got) != want {
if got := m.mem.view(ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -141,40 +141,40 @@ func TestConn_newString(t *testing.T) {
func TestConn_getString(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
ptr := db.newString("")
ptr := m.newString("")
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
str := "sqlite3" + "\000 drop this"
ptr = db.newString(str)
ptr = m.newString(str)
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
want := "sqlite3"
if got := db.mem.readString(ptr, math.MaxUint32); got != want {
if got := m.mem.readString(ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := db.mem.readString(ptr, 0); got != "" {
if got := m.mem.readString(ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
db.mem.readString(ptr, uint32(len(want)/2))
m.mem.readString(ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
db.mem.readString(0, math.MaxUint32)
m.mem.readString(0, math.MaxUint32)
t.Error("want panic")
}()
}
@@ -182,18 +182,18 @@ func TestConn_getString(t *testing.T) {
func TestConn_free(t *testing.T) {
t.Parallel()
db, err := Open(":memory:")
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer m.close()
db.free(0)
m.free(0)
ptr := db.new(1)
ptr := m.new(1)
if ptr == 0 {
t.Error("got nullptr, want a pointer")
}
db.free(ptr)
m.free(ptr)
}

3
sqlite3/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
sqlite3.c
sqlite3.h
sqlite3ext.h

View File

@@ -3,6 +3,8 @@
#include "main.c"
#include "os.c"
#include "qsort.c"
#include "time.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;

View File

@@ -33,7 +33,9 @@
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
#ifndef SQLITE_DEFAULT_LOCKING_MODE
#define SQLITE_DEFAULT_LOCKING_MODE 1
#endif
// Recommended Extensions

28
sqlite3/time.c Normal file
View File

@@ -0,0 +1,28 @@
#include <string.h>
#include "sqlite3.h"
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
const void *pKey2) {
// If keys are of different length, and both terminated by a Z,
// ignore the Z for collation purposes.
if (nKey1 && nKey2 && nKey1 != nKey2) {
const char *pK1 = (const char *)pKey1;
const char *pK2 = (const char *)pKey2;
if (pK1[nKey1 - 1] == 'Z' && pK2[nKey2 - 1] == 'Z') {
nKey1--;
nKey2--;
}
}
int n = nKey1 < nKey2 ? nKey1 : nKey2;
int rc = memcmp(pKey1, pKey2, n);
if (rc == 0) {
rc = nKey1 - nKey2;
}
return rc;
}
int sqlite3_time_collation(sqlite3 *db) {
return sqlite3_create_collation(db, "TIME", SQLITE_UTF8, 0, time_collation);
}

View File

@@ -339,7 +339,7 @@ func (s *Stmt) ColumnText(col int) string {
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.handle))
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return ""
}
@@ -362,7 +362,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.handle))
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return buf[0:0]
}

127
tests/backup_test.go Normal file
View File

@@ -0,0 +1,127 @@
package tests
import (
"path/filepath"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestBackup(t *testing.T) {
t.Parallel()
backupName := filepath.Join(t.TempDir(), "backup.db")
func() { // Create backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
err = db.Backup("main", backupName)
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Restore backup.
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Restore("main", backupName)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 1, 2}
names := []string{"go", "zig", "whatever"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
func() { // Errors.
db, err := sqlite3.Open(backupName)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.BeginExclusive()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
err = db.Restore("main", backupName)
if err == nil {
t.Fatal("want error")
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
}

View File

@@ -202,59 +202,3 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Error("got message:", got)
}
}
func TestConn_MustPrepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(``)
t.Error("want panic")
}
func TestConn_MustPrepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT 1; -- HERE`)
t.Error("want panic")
}
func TestConn_MustPrepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.Pragma("encoding=''")
t.Error("want panic")
}

182
tests/mptest/mptest_test.go Normal file
View File

@@ -0,0 +1,182 @@
package mptest
import (
"bytes"
"context"
"crypto/rand"
"embed"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
_ "unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
_ "github.com/ncruces/go-sqlite3"
)
//go:embed testdata/mptest.wasm
var binary []byte
//go:embed testdata/*.*test
var scripts embed.FS
//go:linkname vfsNewEnvModuleBuilder github.com/ncruces/go-sqlite3.vfsNewEnvModuleBuilder
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder
//go:linkname vfsContext github.com/ncruces/go-sqlite3.vfsContext
func vfsContext(ctx context.Context) (context.Context, io.Closer)
var (
rt wazero.Runtime
module wazero.CompiledModule
instances atomic.Uint64
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfsNewEnvModuleBuilder(rt)
env.NewFunctionBuilder().WithFunc(system).Export("system")
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func config(ctx context.Context) wazero.ModuleConfig {
name := strconv.FormatUint(instances.Add(1), 10)
log := ctx.Value(logger{}).(io.Writer)
fs, err := fs.Sub(scripts, "testdata")
if err != nil {
panic(err)
}
return wazero.NewModuleConfig().
WithName(name).WithStdout(log).WithStderr(log).WithFS(fs).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
}
func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr)
buf = buf[:bytes.IndexByte(buf, 0)]
args := strings.Split(string(buf), " ")
for i := range args {
args[i] = strings.Trim(args[i], `"`)
}
args = args[:len(args)-1]
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := vfsContext(ctx)
rt.InstantiateModule(ctx, module, cfg)
vfs.Close()
}()
return 0
}
func Test_config01(t *testing.T) {
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
}
func Test_config02(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if os.Getenv("CI") != "" {
t.Skip("skipping in CI")
}
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
}
func Test_crash01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
}
func Test_multiwrite01(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfsContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
}
func newContext(t *testing.T) context.Context {
return context.WithValue(context.Background(), logger{}, &testWriter{T: t})
}
type logger struct{}
type testWriter struct {
*testing.T
buf []byte
mtx sync.Mutex
}
func (l *testWriter) Write(p []byte) (n int, err error) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.buf = append(l.buf, p...)
for {
before, after, found := bytes.Cut(l.buf, []byte("\n"))
if !found {
return len(p), nil
}
l.Logf("%s", before)
l.buf = after
}
}

2
tests/mptest/testdata/.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
mptest.wasm filter=lfs diff=lfs merge=lfs -text
*.*test -crlf

1
tests/mptest/testdata/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
mptest.c

26
tests/mptest/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,26 @@
#!/usr/bin/env bash
set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "mptest.c" ]; then
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/mptest.c"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/config01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/config02.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/crash01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/crash02.subtest"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/multiwrite01.test"
fi
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o mptest.wasm main.c test.c \
-I../../../sqlite3 \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid

46
tests/mptest/testdata/config01.test vendored Normal file
View File

@@ -0,0 +1,46 @@
/*
** Configure five tasks in different ways, then run tests.
*/
--if vfsname() GLOB 'unix'
PRAGMA page_size=8192;
--task 1
PRAGMA journal_mode=PERSIST;
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA journal_mode=TRUNCATE;
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA journal_mode=MEMORY;
--end
--task 4
PRAGMA journal_mode=OFF;
--end
--task 4
PRAGMA mmap_size(268435456);
--end
--source multiwrite01.test
--wait all
PRAGMA page_size=16384;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384

123
tests/mptest/testdata/config02.test vendored Normal file
View File

@@ -0,0 +1,123 @@
/*
** Configure five tasks in different ways, then run tests.
*/
PRAGMA page_size=512;
--task 1
PRAGMA mmap_size=0;
--end
--task 2
PRAGMA mmap_size=28672;
--end
--task 3
PRAGMA mmap_size=8192;
--end
--task 4
PRAGMA mmap_size=65536;
--end
--task 5
PRAGMA mmap_size=268435456;
--end
--source multiwrite01.test
--source crash02.subtest
PRAGMA page_size=1024;
VACUUM;
CREATE TABLE pgsz(taskid, sz INTEGER);
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 1024 1024 1024 1024 1024
PRAGMA page_size=2048;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 2048 2048 2048 2048 2048
PRAGMA page_size=8192;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 8192 8192 8192 8192 8192
PRAGMA page_size=16384;
VACUUM;
DELETE FROM pgsz;
--task 1
INSERT INTO pgsz VALUES(1, eval('PRAGMA page_size'));
--end
--task 2
INSERT INTO pgsz VALUES(2, eval('PRAGMA page_size'));
--end
--task 3
INSERT INTO pgsz VALUES(3, eval('PRAGMA page_size'));
--end
--task 4
INSERT INTO pgsz VALUES(4, eval('PRAGMA page_size'));
--end
--task 5
INSERT INTO pgsz VALUES(5, eval('PRAGMA page_size'));
--end
--source multiwrite01.test
--source crash02.subtest
--wait all
SELECT sz FROM pgsz;
--match 16384 16384 16384 16384 16384
PRAGMA auto_vacuum=FULL;
VACUUM;
--source multiwrite01.test
--source crash02.subtest
--wait all
PRAGMA auto_vacuum=FULL;
PRAGMA page_size=512;
VACUUM;
--source multiwrite01.test
--source crash02.subtest

106
tests/mptest/testdata/crash01.test vendored Normal file
View File

@@ -0,0 +1,106 @@
/* Test cases involving incomplete transactions that must be rolled back.
*/
--task 1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait 1
--task 2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
INSERT INTO t2 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
INSERT INTO t3 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
INSERT INTO t4 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
INSERT INTO t5 SELECT a, b FROM t1;
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
/* After the database file has been set up, run the crash2 subscript
** multiple times. */
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest
--source crash02.subtest

53
tests/mptest/testdata/crash02.subtest vendored Normal file
View File

@@ -0,0 +1,53 @@
/*
** This script is called from crash01.test and config02.test and perhaps other
** script. After the database file has been set up, make a big rollback
** journal in client 1, then crash client 1.
** Then in the other clients, do an integrity check.
*/
--task 1 leave-hot-journal
--sleep 5
--finish
PRAGMA cache_size=10;
BEGIN;
UPDATE t1 SET b=randomblob(20000);
UPDATE t2 SET b=randomblob(20000);
UPDATE t3 SET b=randomblob(20000);
UPDATE t4 SET b=randomblob(20000);
UPDATE t5 SET b=randomblob(20000);
UPDATE t1 SET b=NULL;
UPDATE t2 SET b=NULL;
UPDATE t3 SET b=NULL;
UPDATE t4 SET b=NULL;
UPDATE t5 SET b=NULL;
--print Task one crashing an incomplete transaction
--exit 1
--end
--task 2 integrity_check-2
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 3 integrity_check-3
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 4 integrity_check-4
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--task 5 integrity_check-5
SELECT count(*) FROM t1;
--match 64
--sleep 100
PRAGMA integrity_check(10);
--match ok
--end
--wait all

22
tests/mptest/testdata/main.c vendored Normal file
View File

@@ -0,0 +1,22 @@
#include <stdbool.h>
#include <stddef.h>
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
void __attribute__((constructor)) premain() { sqlite3_initialize(); }
int sqlite3_enable_load_extension(sqlite3 *db, int onoff) { return SQLITE_OK; }
void *sqlite3_trace(sqlite3 *db, void (*xTrace)(void *, const char *),
void *pArg) {
return NULL;
}
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}

3
tests/mptest/testdata/mptest.wasm vendored Executable file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3960b873a7dab969a66f7859d491cec0dd4e6c0c9f83eab449fb15ec5ebdfd8f
size 1077281

415
tests/mptest/testdata/multiwrite01.test vendored Normal file
View File

@@ -0,0 +1,415 @@
/*
** This script sets up five different tasks all writing and updating
** the database at the same time, but each in its own table.
*/
--task 1 build-t1
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t1 VALUES(1, randomblob(2000));
INSERT INTO t1 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t1 SELECT a+2, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+4, randomblob(1500) FROM t1;
INSERT INTO t1 SELECT a+8, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+16, randomblob(1500) FROM t1;
--sleep 1
INSERT INTO t1 SELECT a+32, randomblob(1500) FROM t1;
SELECT count(*) FROM t1;
--match 64
SELECT avg(length(b)) FROM t1;
--match 1500.0
--sleep 2
UPDATE t1 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t1;
--match 247
SELECT a FROM t1 WHERE b='x17y';
--match 17
CREATE INDEX t1b ON t1(b);
SELECT a FROM t1 WHERE b='x17y';
--match 17
SELECT a FROM t1 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 2 build-t2
DROP TABLE IF EXISTS t2;
CREATE TABLE t2(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t2 VALUES(1, randomblob(2000));
INSERT INTO t2 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t2 SELECT a+2, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+4, randomblob(1500) FROM t2;
INSERT INTO t2 SELECT a+8, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+16, randomblob(1500) FROM t2;
--sleep 1
INSERT INTO t2 SELECT a+32, randomblob(1500) FROM t2;
SELECT count(*) FROM t2;
--match 64
SELECT avg(length(b)) FROM t2;
--match 1500.0
--sleep 2
UPDATE t2 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t2;
--match 247
SELECT a FROM t2 WHERE b='x17y';
--match 17
CREATE INDEX t2b ON t2(b);
SELECT a FROM t2 WHERE b='x17y';
--match 17
SELECT a FROM t2 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 3 build-t3
DROP TABLE IF EXISTS t3;
CREATE TABLE t3(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t3 VALUES(1, randomblob(2000));
INSERT INTO t3 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t3 SELECT a+2, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+4, randomblob(1500) FROM t3;
INSERT INTO t3 SELECT a+8, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+16, randomblob(1500) FROM t3;
--sleep 1
INSERT INTO t3 SELECT a+32, randomblob(1500) FROM t3;
SELECT count(*) FROM t3;
--match 64
SELECT avg(length(b)) FROM t3;
--match 1500.0
--sleep 2
UPDATE t3 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t3;
--match 247
SELECT a FROM t3 WHERE b='x17y';
--match 17
CREATE INDEX t3b ON t3(b);
SELECT a FROM t3 WHERE b='x17y';
--match 17
SELECT a FROM t3 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 4 build-t4
DROP TABLE IF EXISTS t4;
CREATE TABLE t4(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t4 VALUES(1, randomblob(2000));
INSERT INTO t4 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t4 SELECT a+2, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+4, randomblob(1500) FROM t4;
INSERT INTO t4 SELECT a+8, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+16, randomblob(1500) FROM t4;
--sleep 1
INSERT INTO t4 SELECT a+32, randomblob(1500) FROM t4;
SELECT count(*) FROM t4;
--match 64
SELECT avg(length(b)) FROM t4;
--match 1500.0
--sleep 2
UPDATE t4 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t4;
--match 247
SELECT a FROM t4 WHERE b='x17y';
--match 17
CREATE INDEX t4b ON t4(b);
SELECT a FROM t4 WHERE b='x17y';
--match 17
SELECT a FROM t4 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--task 5 build-t5
DROP TABLE IF EXISTS t5;
CREATE TABLE t5(a INTEGER PRIMARY KEY, b);
--sleep 1
INSERT INTO t5 VALUES(1, randomblob(2000));
INSERT INTO t5 VALUES(2, randomblob(1000));
--sleep 1
INSERT INTO t5 SELECT a+2, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+4, randomblob(1500) FROM t5;
INSERT INTO t5 SELECT a+8, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+16, randomblob(1500) FROM t5;
--sleep 1
INSERT INTO t5 SELECT a+32, randomblob(1500) FROM t5;
SELECT count(*) FROM t5;
--match 64
SELECT avg(length(b)) FROM t5;
--match 1500.0
--sleep 2
UPDATE t5 SET b='x'||a||'y';
SELECT sum(length(b)) FROM t5;
--match 247
SELECT a FROM t5 WHERE b='x17y';
--match 17
CREATE INDEX t5b ON t5(b);
SELECT a FROM t5 WHERE b='x17y';
--match 17
SELECT a FROM t5 WHERE b GLOB 'x2?y' ORDER BY b DESC LIMIT 5;
--match 29 28 27 26 25
--end
--wait all
SELECT count(*), sum(length(b)) FROM t1;
--match 64 247
SELECT count(*), sum(length(b)) FROM t2;
--match 64 247
SELECT count(*), sum(length(b)) FROM t3;
--match 64 247
SELECT count(*), sum(length(b)) FROM t4;
--match 64 247
SELECT count(*), sum(length(b)) FROM t5;
--match 64 247
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
--task 5
DROP INDEX t5b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t5b ON t5(b DESC);
--end
--task 3
DROP INDEX t3b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t3b ON t3(b DESC);
--end
--task 1
DROP INDEX t1b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t1b ON t1(b DESC);
--end
--task 2
DROP INDEX t2b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t2b ON t2(b DESC);
--end
--task 4
DROP INDEX t4b;
--sleep 5
PRAGMA integrity_check(10);
--match ok
CREATE INDEX t4b ON t4(b DESC);
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
--end
--wait all
VACUUM;
PRAGMA integrity_check(10);
--match ok
--task 1
UPDATE t1 SET b=randomblob(20000);
--sleep 5
UPDATE t1 SET b='x'||a||'y';
SELECT a FROM t1 WHERE b='x63y';
--match 63
--end
--task 2
UPDATE t2 SET b=randomblob(20000);
--sleep 5
UPDATE t2 SET b='x'||a||'y';
SELECT a FROM t2 WHERE b='x63y';
--match 63
--end
--task 3
UPDATE t3 SET b=randomblob(20000);
--sleep 5
UPDATE t3 SET b='x'||a||'y';
SELECT a FROM t3 WHERE b='x63y';
--match 63
--end
--task 4
UPDATE t4 SET b=randomblob(20000);
--sleep 5
UPDATE t4 SET b='x'||a||'y';
SELECT a FROM t4 WHERE b='x63y';
--match 63
--end
--task 5
UPDATE t5 SET b=randomblob(20000);
--sleep 5
UPDATE t5 SET b='x'||a||'y';
SELECT a FROM t5 WHERE b='x63y';
--match 63
--end
--wait all
--task 1
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 5
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 3
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 2
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--task 4
SELECT t1.a FROM t1, t2
WHERE t2.b GLOB 'x3?y' AND t1.b=('x'||(t2.a+3)||'y')
ORDER BY t1.a LIMIT 4
--match 33 34 35 36
SELECT t3.a FROM t3, t4
WHERE t4.b GLOB 'x4?y' AND t3.b=('x'||(t4.a+5)||'y')
ORDER BY t3.a LIMIT 7
--match 45 46 47 48 49 50 51
PRAGMA integrity_check;
--match ok
--end
--wait all

5
tests/mptest/testdata/test.c vendored Normal file
View File

@@ -0,0 +1,5 @@
#define unlink dont_unlink
#include "mptest.c"
int dont_unlink(const char *pathname) { return 0; }

View File

@@ -14,8 +14,15 @@ import (
)
func TestParallel(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
testParallel(t, name, 1000)
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -135,7 +142,7 @@ func testParallel(t *testing.T, name string, n int) {
}
var group errgroup.Group
group.SetLimit(4)
group.SetLimit(6)
for i := 0; i < n; i++ {
if i&7 != 7 {
group.Go(reader)

169
tests/time_test.go Normal file
View File

@@ -0,0 +1,169 @@
package tests
import (
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
time time.Time
want any
}{
{sqlite3.TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{sqlite3.TimeFormatJulianDay, reference, 2456572.849526851851852},
{sqlite3.TimeFormatUnix, reference, int64(1381134199)},
{sqlite3.TimeFormatUnixFrac, reference, 1381134199.120},
{sqlite3.TimeFormatUnixMilli, reference, int64(1381134199_120)},
{sqlite3.TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{sqlite3.TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{sqlite3.TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{sqlite3.TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{sqlite3.TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{sqlite3.TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{sqlite3.TimeFormatJulianDay, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{sqlite3.TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{sqlite3.TimeFormatUnix, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnix, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{sqlite3.TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{sqlite3.TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMilli, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{sqlite3.TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{sqlite3.TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixMicro, false, time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{sqlite3.TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{sqlite3.TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatUnixNano, false, time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{sqlite3.TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{sqlite3.TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{sqlite3.TimeFormatAuto, "1381134199", reference, time.Second, false},
{sqlite3.TimeFormatAuto, "1381134199120", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
{sqlite3.TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{sqlite3.TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
stmt.Step()
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
if err != nil {
t.Fatal(err)
}
var t0 time.Time
for stmt.Step() {
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
if t0.After(t1) {
t.Errorf("got %v after %v", t0, t1)
}
t0 = t1
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -29,6 +29,7 @@ func TestConn_Transaction_exec(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
@@ -117,6 +118,7 @@ func TestConn_Transaction_panic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
@@ -183,10 +185,10 @@ func TestConn_Transaction_interrupt(t *testing.T) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
var nilErr error
tx.End(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -208,6 +210,33 @@ func TestConn_Transaction_interrupt(t *testing.T) {
}
}
func TestConn_Transaction_interrupted(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
cancel()
tx := db.Begin()
err = tx.Commit()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
@@ -275,6 +304,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
return stmt.ColumnInt(0)
}
@@ -283,7 +313,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
}
insert := func(succeed bool) (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -341,7 +371,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
}
panics := func() (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -361,6 +391,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
got := stmt.ColumnInt(0)
if got != 1 {
@@ -391,12 +422,12 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
@@ -404,19 +435,19 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
release1 := db.Savepoint()
savept1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
release2 := db.Savepoint()
savept2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint()(&err)
db.Savepoint().Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
@@ -427,15 +458,15 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
}
err = context.Canceled
release2(&err)
savept2.Release(&err)
if err != context.Canceled {
t.Fatal(err)
}
var nilErr error
release1(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
savept1.Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -471,7 +502,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
@@ -480,7 +511,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,19 +1,23 @@
package sqlite3
package tests
import "testing"
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func TestDatatype_String(t *testing.T) {
t.Parallel()
tests := []struct {
data Datatype
data sqlite3.Datatype
want string
}{
{INTEGER, "INTEGER"},
{FLOAT, "FLOAT"},
{TEXT, "TEXT"},
{BLOB, "BLOB"},
{NULL, "NULL"},
{sqlite3.INTEGER, "INTEGER"},
{sqlite3.FLOAT, "FLOAT"},
{sqlite3.TEXT, "TEXT"},
{sqlite3.BLOB, "BLOB"},
{sqlite3.NULL, "NULL"},
{10, "10"},
}
for _, tt := range tests {

19
time.go
View File

@@ -62,13 +62,18 @@ const (
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving any timezone offset.
//
// This is the format used by the database/sql driver:
// [database/sql.Row.Scan] is able to decode as [time.Time]
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
// Use [TimeFormat7] for time-ordered encoding.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
@@ -78,6 +83,8 @@ const (
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
@@ -123,9 +130,9 @@ func (f TimeFormat) Encode(t time.Time) any {
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
// are safe to use for current events, from 1980 through at least 2260.
// Unix timestamps before 1980 may be misinterpreted as julian day numbers,
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html

View File

@@ -1,118 +0,0 @@
package sqlite3
import (
"reflect"
"testing"
"time"
)
func TestTimeFormat_Encode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
time time.Time
want any
}{
{TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"},
{TimeFormatJulianDay, reference, 2456572.849526851851852},
{TimeFormatUnix, reference, int64(1381134199)},
{TimeFormatUnixFrac, reference, 1381134199.120},
{TimeFormatUnixMilli, reference, int64(1381134199_120)},
{TimeFormatUnixMicro, reference, int64(1381134199_120000)},
{TimeFormatUnixNano, reference, int64(1381134199_120000000)},
{TimeFormat7, reference, "2013-10-07T08:23:19.120"},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want)
}
})
}
}
func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt TimeFormat
val any
want time.Time
wantDelta time.Duration
wantErr bool
}{
{TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false},
{TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false},
{TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false},
{TimeFormatJulianDay, false, time.Time{}, 0, true},
{TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false},
{TimeFormatUnix, int64(1381134199), reference, time.Second, false},
{TimeFormatUnix, "abc", time.Time{}, 0, true},
{TimeFormatUnix, false, time.Time{}, 0, true},
{TimeFormatUnixMilli, "1381134199120", reference, 0, false},
{TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false},
{TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false},
{TimeFormatUnixMilli, "abc", time.Time{}, 0, true},
{TimeFormatUnixMilli, false, time.Time{}, 0, true},
{TimeFormatUnixMicro, "1381134199120000", reference, 0, false},
{TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false},
{TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false},
{TimeFormatUnixMicro, "abc", time.Time{}, 0, true},
{TimeFormatUnixMicro, false, time.Time{}, 0, true},
{TimeFormatUnixNano, "1381134199120000000", reference, 0, false},
{TimeFormatUnixNano, 1381134199.120e9, reference, 0, false},
{TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false},
{TimeFormatUnixNano, "abc", time.Time{}, 0, true},
{TimeFormatUnixNano, false, time.Time{}, 0, true},
{TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false},
{TimeFormatAuto, "2456572", reference, 24 * time.Hour, false},
{TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false},
{TimeFormatAuto, "1381134199", reference, time.Second, false},
{TimeFormatAuto, "1381134199120", reference, 0, false},
{TimeFormatAuto, "1381134199120000", reference, 0, false},
{TimeFormatAuto, "1381134199120000000", reference, 0, false},
{TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormatAuto, "abc", time.Time{}, 0, true},
{TimeFormatAuto, false, time.Time{}, 0, true},
{TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{TimeFormat9, "08:23:19.12", reftime, 0, false},
{TimeFormat3, false, time.Time{}, 0, true},
{TimeFormat9, false, time.Time{}, 0, true},
{TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false},
{TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false},
{TimeFormatDefault, false, time.Time{}, 0, true},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.fmt.Decode(tt.val)
if (err != nil) != tt.wantErr {
t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr)
return
}
if tt.want.Sub(got).Abs() > tt.wantDelta {
t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want)
}
})
}
}

157
tx.go
View File

@@ -4,9 +4,14 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
)
// Tx is an in-progress database transaction.
//
// https://www.sqlite.org/lang_transaction.html
type Tx struct {
c *Conn
}
@@ -15,8 +20,9 @@ type Tx struct {
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
err := c.Exec(`BEGIN DEFERRED`)
if err != nil && !errors.Is(err, INTERRUPT) {
// BEGIN even if interrupted.
err := c.txExecInterrupted(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Tx{c}
@@ -63,21 +69,22 @@ func (tx Tx) End(errp *error) {
defer panic(recovered)
}
if tx.c.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
if (errp == nil || *errp == nil) && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
@@ -91,33 +98,28 @@ func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rollsback the transaction.
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
// ROLLBACK even if the connection has been interrupted.
old := tx.c.SetInterrupt(context.Background())
defer tx.c.SetInterrupt(old)
return tx.c.Exec(`ROLLBACK`)
return tx.c.txExecInterrupted(`ROLLBACK`)
}
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a release func that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// // ... do work in the transaction
// }
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
@@ -126,52 +128,75 @@ func (c *Conn) Savepoint() (release func(*error)) {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
if errors.Is(err, INTERRUPT) {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return Savepoint{c: c, name: name}
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// savept := conn.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if c.GetAutocommit() {
// There is nothing to commit/rollback.
if (errp == nil || *errp == nil) && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txExecInterrupted(fmt.Sprintf(`
ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://www.sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
}
func (c *Conn) txExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
}

128
vfs.go
View File

@@ -9,7 +9,6 @@ import (
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/julianday"
@@ -26,6 +25,14 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
panic(err)
}
env := vfsNewEnvModuleBuilder(r)
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder {
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("os_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("os_randomness")
@@ -46,15 +53,39 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("os_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("os_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("os_file_control")
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
return env
}
type vfsOSMethods bool
// Poor man's namespaces.
const (
vfsOS vfsOSMethods = false
vfsFile vfsFileMethods = false
)
const vfsOS vfsOSMethods = false
type (
vfsOSMethods bool
vfsFileMethods bool
)
type vfsKey struct{}
type vfsState struct {
files []*os.File
}
func vfsContext(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
@@ -134,66 +165,34 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32)
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _OK
return uint32(IOERR_DELETE_NOENT)
}
if err != nil {
return uint32(IOERR_DELETE)
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err == nil {
err = f.Sync()
f.Close()
}
if err != nil {
return uint32(IOERR_DELETE)
return _OK
}
defer f.Close()
err = vfsOS.Sync(f, false, false)
if err != nil {
return uint32(IOERR_DIR_FSYNC)
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
fi, err := os.Stat(path)
ok, rc := vfsOS.Access(path, flags)
var res uint32
switch {
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
return uint32(IOERR_ACCESS)
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
res = 1
} else {
res = 0
}
case errors.Is(err, fs.ErrPermission):
res = 0
default:
return uint32(IOERR_ACCESS)
if ok {
res = 1
}
memory{mod}.writeUint32(pResOut, res)
return _OK
return uint32(rc)
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 {
@@ -217,18 +216,17 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla
file, err = os.CreateTemp("", "*.db")
} else {
name := memory{mod}.readString(zName, _MAX_PATHNAME)
file, err = os.OpenFile(name, oflags, 0600)
file, err = vfsOS.OpenFile(name, oflags, 0600)
}
if err != nil {
return uint32(CANTOPEN)
}
if flags&OPEN_DELETEONCLOSE != 0 {
vfsOS.DeleteOnClose(file)
os.Remove(file.Name())
}
id := vfsGetFileID(file)
vfsFilePtr{mod, pFile}.SetID(id).SetLock(_NO_LOCK)
vfsFile.Open(ctx, mod, pFile, file)
if pOutFlags != 0 {
memory{mod}.writeUint32(pOutFlags, uint32(flags))
@@ -237,8 +235,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
id := vfsFilePtr{mod, pFile}.ID()
err := vfsCloseFile(id)
err := vfsFile.Close(ctx, mod, pFile)
if err != nil {
return uint32(IOERR_CLOSE)
}
@@ -248,7 +245,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
n, err := file.ReadAt(buf, int64(iOfst))
if n == int(iAmt) {
return _OK
@@ -265,7 +262,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
_, err := file.WriteAt(buf, int64(iOfst))
if err != nil {
return uint32(IOERR_WRITE)
@@ -274,7 +271,7 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOf
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
err := file.Truncate(int64(nByte))
if err != nil {
return uint32(IOERR_TRUNCATE)
@@ -282,9 +279,11 @@ func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
file := vfsFilePtr{mod, pFile}.OSFile()
err := file.Sync()
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) uint32 {
dataonly := (flags & _SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == _SYNC_FULL
file := vfsFile.GetOS(ctx, mod, pFile)
err := vfsOS.Sync(file, fullsync, dataonly)
if err != nil {
return uint32(IOERR_FSYNC)
}
@@ -292,10 +291,7 @@ func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 {
// This uses [os.File.Seek] because we don't care about the offset for reading/writing.
// But consider using [os.File.Stat] instead (as other VFSes do).
file := vfsFilePtr{mod, pFile}.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return uint32(IOERR_SEEK)

57
vfs_file.go Normal file
View File

@@ -0,0 +1,57 @@
package sqlite3
import (
"context"
"os"
"github.com/tetratelabs/wazero/api"
)
func (vfsFileMethods) NewID(ctx context.Context, file *os.File) uint32 {
vfs := ctx.Value(vfsKey{}).(*vfsState)
// Find an empty slot.
for id, ptr := range vfs.files {
if ptr == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func (vfsFileMethods) Open(ctx context.Context, mod api.Module, pFile uint32, file *os.File) {
mem := memory{mod}
id := vfsFile.NewID(ctx, file)
mem.writeUint32(pFile+ptrlen, id)
mem.writeUint32(pFile+2*ptrlen, _NO_LOCK)
}
func (vfsFileMethods) Close(ctx context.Context, mod api.Module, pFile uint32) error {
mem := memory{mod}
id := mem.readUint32(pFile + ptrlen)
vfs := ctx.Value(vfsKey{}).(*vfsState)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
}
func (vfsFileMethods) GetOS(ctx context.Context, mod api.Module, pFile uint32) *os.File {
mem := memory{mod}
id := mem.readUint32(pFile + ptrlen)
vfs := ctx.Value(vfsKey{}).(*vfsState)
return vfs.files[id]
}
func (vfsFileMethods) GetLock(ctx context.Context, mod api.Module, pFile uint32) vfsLockState {
mem := memory{mod}
return vfsLockState(mem.readUint32(pFile + 2*ptrlen))
}
func (vfsFileMethods) SetLock(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) {
mem := memory{mod}
mem.writeUint32(pFile+2*ptrlen, uint32(lock))
}

View File

@@ -1,69 +0,0 @@
package sqlite3
import (
"os"
"sync"
"github.com/tetratelabs/wazero/api"
)
var (
vfsOpenFiles []*os.File
vfsOpenFilesMtx sync.Mutex
)
func vfsGetFileID(file *os.File) uint32 {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
// Find an empty slot.
for id, ptr := range vfsOpenFiles {
if ptr == nil {
vfsOpenFiles[id] = file
return uint32(id)
}
}
// Add a new slot.
vfsOpenFiles = append(vfsOpenFiles, file)
return uint32(len(vfsOpenFiles) - 1)
}
func vfsCloseFile(id uint32) error {
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
file := vfsOpenFiles[id]
vfsOpenFiles[id] = nil
return file.Close()
}
type vfsFilePtr struct {
api.Module
ptr uint32
}
func (p vfsFilePtr) OSFile() *os.File {
id := p.ID()
vfsOpenFilesMtx.Lock()
defer vfsOpenFilesMtx.Unlock()
return vfsOpenFiles[id]
}
func (p vfsFilePtr) ID() uint32 {
return memory{p}.readUint32(p.ptr + ptrlen)
}
func (p vfsFilePtr) Lock() vfsLockState {
return vfsLockState(memory{p}.readUint32(p.ptr + 2*ptrlen))
}
func (p vfsFilePtr) SetID(id uint32) vfsFilePtr {
memory{p}.writeUint32(p.ptr+ptrlen, id)
return p
}
func (p vfsFilePtr) SetLock(lock vfsLockState) vfsFilePtr {
memory{p}.writeUint32(p.ptr+2*ptrlen, uint32(lock))
return p
}

View File

@@ -61,9 +61,8 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
file := ptr.OSFile()
cLock := ptr.Lock()
file := vfsFile.GetOS(ctx, mod, pFile)
cLock := vfsFile.GetLock(ctx, mod, pFile)
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
@@ -95,7 +94,7 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
if rc := vfsOS.GetSharedLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK)
return _OK
case _RESERVED_LOCK:
@@ -106,7 +105,7 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
if rc := vfsOS.GetReservedLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_RESERVED_LOCK)
vfsFile.SetLock(ctx, mod, pFile, _RESERVED_LOCK)
return _OK
case _EXCLUSIVE_LOCK:
@@ -119,12 +118,12 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta
if rc := vfsOS.GetPendingLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_PENDING_LOCK)
vfsFile.SetLock(ctx, mod, pFile, _PENDING_LOCK)
}
if rc := vfsOS.GetExclusiveLock(file); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_EXCLUSIVE_LOCK)
vfsFile.SetLock(ctx, mod, pFile, _EXCLUSIVE_LOCK)
return _OK
default:
@@ -138,9 +137,8 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
panic(assertErr())
}
ptr := vfsFilePtr{mod, pFile}
file := ptr.OSFile()
cLock := ptr.Lock()
file := vfsFile.GetOS(ctx, mod, pFile)
cLock := vfsFile.GetLock(ctx, mod, pFile)
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
@@ -157,12 +155,12 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK {
return uint32(rc)
}
ptr.SetLock(_SHARED_LOCK)
vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK)
return _OK
case _NO_LOCK:
rc := vfsOS.ReleaseLock(file, cLock)
ptr.SetLock(_NO_LOCK)
vfsFile.SetLock(ctx, mod, pFile, _NO_LOCK)
return uint32(rc)
default:
@@ -171,14 +169,13 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
ptr := vfsFilePtr{mod, pFile}
cLock := ptr.Lock()
cLock := vfsFile.GetLock(ctx, mod, pFile)
if cLock > _SHARED_LOCK {
panic(assertErr())
}
file := ptr.OSFile()
file := vfsFile.GetOS(ctx, mod, pFile)
locked, rc := vfsOS.CheckReservedLock(file)
var res uint32

View File

@@ -38,10 +38,13 @@ func Test_vfsLock(t *testing.T) {
pOutput = 32
)
mem := newMemory(128)
vfsFilePtr{mem.mod, pFile1}.SetID(vfsGetFileID(file1)).SetLock(_NO_LOCK)
vfsFilePtr{mem.mod, pFile2}.SetID(vfsGetFileID(file2)).SetLock(_NO_LOCK)
ctx, vfs := vfsContext(context.TODO())
defer vfs.Close()
rc := vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
vfsFile.Open(ctx, mem.mod, pFile1, file1)
vfsFile.Open(ctx, mem.mod, pFile2, file2)
rc := vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -49,12 +52,12 @@ func Test_vfsLock(t *testing.T) {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -62,16 +65,16 @@ func Test_vfsLock(t *testing.T) {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _RESERVED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _RESERVED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -79,12 +82,12 @@ func Test_vfsLock(t *testing.T) {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile2, _EXCLUSIVE_LOCK)
rc = vfsLock(ctx, mem.mod, pFile2, _EXCLUSIVE_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -92,12 +95,12 @@ func Test_vfsLock(t *testing.T) {
t.Error("file wasn't locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -105,12 +108,12 @@ func Test_vfsLock(t *testing.T) {
t.Error("file wasn't locked")
}
rc = vfsUnlock(context.TODO(), mem.mod, pFile2, _SHARED_LOCK)
rc = vfsUnlock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(context.TODO(), mem.mod, pFile1, pOutput)
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -118,7 +121,7 @@ func Test_vfsLock(t *testing.T) {
t.Error("file was locked")
}
rc = vfsLock(context.TODO(), mem.mod, pFile1, _SHARED_LOCK)
rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}

View File

@@ -16,15 +16,17 @@ import (
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
defer func() { _ = recover() }()
vfsExit(context.TODO(), mem.mod, 1)
vfsExit(ctx, mem.mod, 1)
t.Error("want panic")
}
func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
rc := vfsLocaltime(context.TODO(), mem.mod, 0, 4)
rc := vfsLocaltime(ctx, mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -71,24 +73,26 @@ func Test_vfsRandomness(t *testing.T) {
}
func Test_vfsSleep(t *testing.T) {
start := time.Now()
ctx := context.TODO()
rc := vfsSleep(context.TODO(), 0, 123456)
now := time.Now()
rc := vfsSleep(ctx, 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
want := 123456 * time.Microsecond
if got := time.Since(start); got < want {
if got := time.Since(now); got < want {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
now := time.Now()
rc := vfsCurrentTime(context.TODO(), mem.mod, 0, 4)
rc := vfsCurrentTime(ctx, mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -101,10 +105,11 @@ func Test_vfsCurrentTime(t *testing.T) {
func Test_vfsCurrentTime64(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(context.TODO(), mem.mod, 0, 4)
rc := vfsCurrentTime64(ctx, mem.mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -119,13 +124,14 @@ func Test_vfsCurrentTime64(t *testing.T) {
func Test_vfsFullPathname(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, ".")
ctx := context.TODO()
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
rc := vfsFullPathname(ctx, mem.mod, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(context.TODO(), mem.mod, 0, 4, _MAX_PATHNAME, 8)
rc = vfsFullPathname(ctx, mem.mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -147,8 +153,9 @@ func Test_vfsDelete(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, name)
ctx := context.TODO()
rc := vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
rc := vfsDelete(ctx, mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -157,8 +164,8 @@ func Test_vfsDelete(t *testing.T) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
rc = vfsDelete(ctx, mem.mod, 0, 4, 1)
if rc != uint32(IOERR_DELETE_NOENT) {
t.Fatal("returned", rc)
}
}
@@ -177,8 +184,9 @@ func Test_vfsAccess(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
ctx := context.TODO()
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_EXISTS, 4)
rc := vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -186,7 +194,7 @@ func Test_vfsAccess(t *testing.T) {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -195,7 +203,7 @@ func Test_vfsAccess(t *testing.T) {
}
mem.writeString(8, file)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, _ACCESS_READWRITE, 4)
rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -206,9 +214,11 @@ func Test_vfsAccess(t *testing.T) {
func Test_vfsFile(t *testing.T) {
mem := newMemory(128)
ctx, vfs := vfsContext(context.TODO())
defer vfs.Close()
// Open a temporary file.
rc := vfsOpen(context.TODO(), mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
rc := vfsOpen(ctx, mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -216,13 +226,13 @@ func Test_vfsFile(t *testing.T) {
// Write stuff.
text := "Hello world!"
mem.writeString(16, text)
rc = vfsWrite(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 0)
rc = vfsWrite(ctx, mem.mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
rc = vfsFileSize(ctx, mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -231,7 +241,7 @@ func Test_vfsFile(t *testing.T) {
}
// Partial read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 4)
rc = vfsRead(ctx, mem.mod, 4, 16, uint32(len(text)), 4)
if rc != uint32(IOERR_SHORT_READ) {
t.Fatal("returned", rc)
}
@@ -240,13 +250,13 @@ func Test_vfsFile(t *testing.T) {
}
// Truncate the file.
rc = vfsTruncate(context.TODO(), mem.mod, 4, 4)
rc = vfsTruncate(ctx, mem.mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
rc = vfsFileSize(ctx, mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -255,7 +265,7 @@ func Test_vfsFile(t *testing.T) {
}
// Read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 32, 4, 0)
rc = vfsRead(ctx, mem.mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -264,7 +274,7 @@ func Test_vfsFile(t *testing.T) {
}
// Close the file.
rc = vfsClose(context.TODO(), mem.mod, 4)
rc = vfsClose(ctx, mem.mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}

View File

@@ -3,13 +3,46 @@
package sqlite3
import (
"io/fs"
"os"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
func (vfsOSMethods) DeleteOnClose(file *os.File) {
_ = os.Remove(file.Name())
func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func (vfsOSMethods) Access(path string, flags _AccessFlag) (bool, xErrorCode) {
var access uint32 = unix.F_OK
switch flags {
case _ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case _ACCESS_READ:
access = unix.R_OK
}
err := unix.Access(path, access)
if err == nil {
return true, _OK
}
return false, _OK
}
func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error {
if runtime.GOOS == "darwin" && !fullsync {
return unix.Fsync(int(file.Fd()))
}
if runtime.GOOS == "linux" && dataonly {
//lint:ignore SA1019 OK on linux
_, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0)
if err != 0 {
return err
}
return nil
}
return file.Sync()
}
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
@@ -39,8 +72,8 @@ func (vfsOSMethods) ReleaseLock(file *os.File, _ vfsLockState) xErrorCode {
}
func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode {
err := vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_UNLCK,
err := vfsOS.fcntlSetLock(file, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
@@ -51,85 +84,85 @@ func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode {
}
func (vfsOSMethods) readLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_RDLCK,
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}), IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_WRLCK,
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &unix.Flock_t{
Type: unix.F_WRLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) {
lock := syscall.Flock_t{
Type: syscall.F_RDLCK,
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if vfsOS.fcntlGetLock(file, &lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != syscall.F_UNLCK, _OK
return lock.Type != unix.F_UNLCK, _OK
}
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *syscall.Flock_t) error {
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *unix.Flock_t) error {
var F_OFD_GETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_OFD_GETLK = 36 // F_OFD_GETLK
F_OFD_GETLK = 36
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_OFD_GETLK = 92 // F_OFD_GETLK
F_OFD_GETLK = 92
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_OFD_GETLK = 47 // F_OFD_GETLK
F_OFD_GETLK = 47
default:
return notImplErr
}
return syscall.FcntlFlock(file.Fd(), F_OFD_GETLK, lock)
return unix.FcntlFlock(file.Fd(), F_OFD_GETLK, lock)
}
func (vfsOSMethods) fcntlSetLock(file *os.File, lock *syscall.Flock_t) error {
func (vfsOSMethods) fcntlSetLock(file *os.File, lock *unix.Flock_t) error {
var F_OFD_SETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_OFD_SETLK = 37 // F_OFD_SETLK
F_OFD_SETLK = 37
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_OFD_SETLK = 90 // F_OFD_SETLK
F_OFD_SETLK = 90
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_OFD_SETLK = 48 // F_OFD_SETLK
F_OFD_SETLK = 48
default:
return notImplErr
}
return syscall.FcntlFlock(file.Fd(), F_OFD_SETLK, lock)
return unix.FcntlFlock(file.Fd(), F_OFD_SETLK, lock)
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(syscall.Errno); ok {
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
case unix.EPERM:
return xErrorCode(PERM)
}
}

View File

@@ -1,13 +1,69 @@
package sqlite3
import (
"errors"
"io/fs"
"os"
"syscall"
"golang.org/x/sys/windows"
)
func (vfsOSMethods) DeleteOnClose(file *os.File) {}
// OpenFile is a simplified copy of [os.openFileNolog]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/os/file_windows.go
//
// See: https://go.dev/issue/32088
func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT}
}
r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func (vfsOSMethods) Access(path string, flags _AccessFlag) (bool, xErrorCode) {
fi, err := os.Stat(path)
switch {
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
return true, _OK
case errors.Is(err, fs.ErrNotExist):
return false, _OK
default:
return false, IOERR_ACCESS
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
return true, _OK
} else {
return false, _OK
}
case errors.Is(err, fs.ErrPermission):
return false, _OK
default:
return false, IOERR_ACCESS
}
}
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Release the SHARED lock.
@@ -63,6 +119,9 @@ func (vfsOSMethods) ReleaseLock(file *os.File, state vfsLockState) xErrorCode {
func (vfsOSMethods) unlock(file *os.File, start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err == windows.ERROR_NOT_LOCKED {
return _OK
}
if err != nil {
return IOERR_UNLOCK
}
@@ -95,8 +154,66 @@ func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, _ := err.(syscall.Errno); errno == windows.ERROR_INVALID_HANDLE {
return def
if errno, ok := err.(syscall.Errno); ok {
// https://devblogs.microsoft.com/oldnewthing/20140905-00/?p=63
switch errno {
case
windows.ERROR_LOCK_VIOLATION,
windows.ERROR_IO_PENDING:
return xErrorCode(BUSY)
}
}
return xErrorCode(BUSY)
return def
}
// syscallOpen is a simplified copy of [syscall.Open]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
var access uint32
switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) {
case syscall.O_RDONLY:
access = syscall.GENERIC_READ
case syscall.O_WRONLY:
access = syscall.GENERIC_WRITE
case syscall.O_RDWR:
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
}
if mode&syscall.O_CREAT != 0 {
access |= syscall.GENERIC_WRITE
}
if mode&syscall.O_APPEND != 0 {
access &^= syscall.GENERIC_WRITE
access |= syscall.FILE_APPEND_DATA
}
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var createmode uint32
switch {
case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL):
createmode = syscall.CREATE_NEW
case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC):
createmode = syscall.CREATE_ALWAYS
case mode&syscall.O_CREAT == syscall.O_CREAT:
createmode = syscall.OPEN_ALWAYS
case mode&syscall.O_TRUNC == syscall.O_TRUNC:
createmode = syscall.TRUNCATE_EXISTING
default:
createmode = syscall.OPEN_EXISTING
}
var attrs uint32 = syscall.FILE_ATTRIBUTE_NORMAL
if perm&syscall.S_IWRITE == 0 {
attrs = syscall.FILE_ATTRIBUTE_READONLY
}
if createmode == syscall.OPEN_EXISTING && access == syscall.GENERIC_READ {
// Necessary for opening directory handles.
attrs |= syscall.FILE_FLAG_BACKUP_SEMANTICS
}
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, attrs, 0)
}